Skip to content

Commit aa16ef0

Browse files
epilysstsquad
authored andcommitted
sound/pipewire: handle StreamWithIdNotFound consistently
When handling controlq messages, handle StreamWithIdNotFound the same way on all methods: - if we are passed a &mut ControlMessage, set it to VIRTIO_SND_S_BAD_MSG. - if we are passed a stream_id, return StreamWithIdNotFound. Signed-off-by: Manos Pitsidianakis <[email protected]>
1 parent 208a796 commit aa16ef0

File tree

1 file changed

+71
-51
lines changed
  • staging/vhost-device-sound/src/audio_backends

1 file changed

+71
-51
lines changed

staging/vhost-device-sound/src/audio_backends/pipewire.rs

Lines changed: 71 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -139,21 +139,22 @@ impl AudioBackend for PwBackend {
139139
{
140140
let stream_clone = self.stream_params.clone();
141141
let mut stream_params = stream_clone.write().unwrap();
142-
let st = stream_params
143-
.get_mut(stream_id as usize)
144-
.expect("Stream does not exist");
145-
if let Err(err) = st.state.set_parameters() {
146-
log::error!("Stream {} set_parameters {}", stream_id, err);
147-
msg.code = VIRTIO_SND_S_BAD_MSG;
148-
} else if !st.supports_format(request.format) || !st.supports_rate(request.rate) {
149-
msg.code = VIRTIO_SND_S_NOT_SUPP;
142+
if let Some(st) = stream_params.get_mut(stream_id as usize) {
143+
if let Err(err) = st.state.set_parameters() {
144+
log::error!("Stream {} set_parameters {}", stream_id, err);
145+
msg.code = VIRTIO_SND_S_BAD_MSG;
146+
} else if !st.supports_format(request.format) || !st.supports_rate(request.rate) {
147+
msg.code = VIRTIO_SND_S_NOT_SUPP;
148+
} else {
149+
st.params.features = request.features;
150+
st.params.buffer_bytes = request.buffer_bytes;
151+
st.params.period_bytes = request.period_bytes;
152+
st.params.channels = request.channels;
153+
st.params.format = request.format;
154+
st.params.rate = request.rate;
155+
}
150156
} else {
151-
st.params.features = request.features;
152-
st.params.buffer_bytes = request.buffer_bytes;
153-
st.params.period_bytes = request.period_bytes;
154-
st.params.channels = request.channels;
155-
st.params.format = request.format;
156-
st.params.rate = request.rate;
157+
msg.code = VIRTIO_SND_S_BAD_MSG;
157158
}
158159
}
159160
drop(msg);
@@ -163,7 +164,12 @@ impl AudioBackend for PwBackend {
163164

164165
fn prepare(&self, stream_id: u32) -> Result<()> {
165166
debug!("pipewire prepare");
166-
let prepare_result = self.stream_params.write().unwrap()[stream_id as usize]
167+
let prepare_result = self
168+
.stream_params
169+
.write()
170+
.unwrap()
171+
.get_mut(stream_id as usize)
172+
.ok_or_else(|| Error::StreamWithIdNotFound(stream_id))?
167173
.state
168174
.prepare();
169175
if let Err(err) = prepare_result {
@@ -401,69 +407,83 @@ impl AudioBackend for PwBackend {
401407

402408
fn release(&self, stream_id: u32, mut msg: ControlMessage) -> Result<()> {
403409
debug!("pipewire backend, release function");
404-
let release_result = self.stream_params.write().unwrap()[stream_id as usize]
410+
let release_result = self
411+
.stream_params
412+
.write()
413+
.unwrap()
414+
.get_mut(stream_id as usize)
415+
.ok_or_else(|| {
416+
msg.code = VIRTIO_SND_S_BAD_MSG;
417+
Error::StreamWithIdNotFound(stream_id)
418+
})?
405419
.state
406420
.release();
407421
if let Err(err) = release_result {
408422
log::error!("Stream {} release {}", stream_id, err);
409423
msg.code = VIRTIO_SND_S_BAD_MSG;
410-
} else {
411-
let lock_guard = self.thread_loop.lock();
412-
let mut stream_hash = self.stream_hash.write().unwrap();
413-
let mut stream_listener = self.stream_listener.write().unwrap();
414-
let st_buffer = &mut self.stream_params.write().unwrap();
415-
416-
let Some(stream) = stream_hash.get(&stream_id) else {
417-
return Err(Error::StreamWithIdNotFound(stream_id));
418-
};
419-
stream.disconnect().expect("could not disconnect stream");
420-
std::mem::take(&mut st_buffer[stream_id as usize].buffers);
421-
stream_hash.remove(&stream_id);
422-
stream_listener.remove(&stream_id);
423-
424-
lock_guard.unlock();
424+
return Ok(());
425425
}
426-
426+
let lock_guard = self.thread_loop.lock();
427+
let mut stream_hash = self.stream_hash.write().unwrap();
428+
let mut stream_listener = self.stream_listener.write().unwrap();
429+
let st_buffer = &mut self.stream_params.write().unwrap();
430+
let stream = stream_hash
431+
.get(&stream_id)
432+
.expect("Could not find stream with this id in `stream_hash`.");
433+
stream.disconnect().expect("could not disconnect stream");
434+
std::mem::take(&mut st_buffer[stream_id as usize].buffers);
435+
stream_hash.remove(&stream_id);
436+
stream_listener.remove(&stream_id);
437+
lock_guard.unlock();
427438
Ok(())
428439
}
429440

430441
fn start(&self, stream_id: u32) -> Result<()> {
431442
debug!("pipewire start");
432-
let start_result = self.stream_params.write().unwrap()[stream_id as usize]
443+
let start_result = self
444+
.stream_params
445+
.write()
446+
.unwrap()
447+
.get_mut(stream_id as usize)
448+
.ok_or_else(|| Error::StreamWithIdNotFound(stream_id))?
433449
.state
434450
.start();
435451
if let Err(err) = start_result {
436452
// log the error and continue
437453
log::error!("Stream {} start {}", stream_id, err);
438-
} else {
439-
let lock_guard = self.thread_loop.lock();
440-
let stream_hash = self.stream_hash.read().unwrap();
441-
let Some(stream) = stream_hash.get(&stream_id) else {
442-
return Err(Error::StreamWithIdNotFound(stream_id));
443-
};
444-
stream.set_active(true).expect("could not start stream");
445-
lock_guard.unlock();
454+
return Ok(());
446455
}
456+
let lock_guard = self.thread_loop.lock();
457+
let stream_hash = self.stream_hash.read().unwrap();
458+
let stream = stream_hash
459+
.get(&stream_id)
460+
.expect("Could not find stream with this id in `stream_hash`.");
461+
stream.set_active(true).expect("could not start stream");
462+
lock_guard.unlock();
447463
Ok(())
448464
}
449465

450466
fn stop(&self, stream_id: u32) -> Result<()> {
451467
debug!("pipewire stop");
452-
let stop_result = self.stream_params.write().unwrap()[stream_id as usize]
468+
let stop_result = self
469+
.stream_params
470+
.write()
471+
.unwrap()
472+
.get_mut(stream_id as usize)
473+
.ok_or_else(|| Error::StreamWithIdNotFound(stream_id))?
453474
.state
454475
.stop();
455476
if let Err(err) = stop_result {
456477
log::error!("Stream {} stop {}", stream_id, err);
457-
} else {
458-
let lock_guard = self.thread_loop.lock();
459-
let stream_hash = self.stream_hash.read().unwrap();
460-
let Some(stream) = stream_hash.get(&stream_id) else {
461-
return Err(Error::StreamWithIdNotFound(stream_id));
462-
};
463-
stream.set_active(false).expect("could not stop stream");
464-
lock_guard.unlock();
478+
return Ok(());
465479
}
466-
480+
let lock_guard = self.thread_loop.lock();
481+
let stream_hash = self.stream_hash.read().unwrap();
482+
let stream = stream_hash
483+
.get(&stream_id)
484+
.expect("Could not find stream with this id in `stream_hash`.");
485+
stream.set_active(false).expect("could not stop stream");
486+
lock_guard.unlock();
467487
Ok(())
468488
}
469489
}

0 commit comments

Comments
 (0)