diff --git a/drivers/usb/xhcid/src/xhci/irq_reactor.rs b/drivers/usb/xhcid/src/xhci/irq_reactor.rs index ac492d5b..68317b2d 100644 --- a/drivers/usb/xhcid/src/xhci/irq_reactor.rs +++ b/drivers/usb/xhcid/src/xhci/irq_reactor.rs @@ -8,8 +8,11 @@ use std::task; use std::os::unix::io::AsRawFd; use crossbeam_channel::{Receiver, Sender}; +use futures::task::noop_waker; use log::{debug, error, info, trace, warn}; +use common::timeout::Timeout; + use super::doorbell::Doorbell; use super::event::EventRing; use super::ring::Ring; @@ -44,6 +47,30 @@ pub struct NextEventTrb { pub src_trb: Option, } +pub struct PendingEventWait { + message: Arc>>, +} + +impl PendingEventWait { + pub fn wait_timeout(&self, timeout: std::time::Duration) -> Option { + let timeout = Timeout::new(timeout); + + loop { + let Ok(mut message) = self.message.lock() else { + return None; + }; + + if let Some(message) = message.take() { + return Some(message); + } + + drop(message); + + timeout.run().ok()?; + } + } +} + // TODO: Perhaps all of the transfer rings used by the xHC should be stored linearly, and then // indexed using this struct instead. #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] @@ -626,6 +653,27 @@ impl Future for EventTrbFuture { } impl Xhci { + pub fn queue_command_completion_wait( + &self, + command_ring: &Ring, + trb: &Trb, + doorbell: EventDoorbell, + ) -> PendingEventWait { + let message = Arc::new(Mutex::new(None)); + let send_result = self.irq_reactor_sender.send(State { + waker: noop_waker(), + kind: StateKind::CommandCompletion { + phys_ptr: command_ring.trb_phys_ptr(self.cap.ac64(), trb), + }, + message: Arc::clone(&message), + is_isoch_or_vf: false, + }); + if send_result.is_ok() { + doorbell.ring(); + } + PendingEventWait { message } + } + pub fn get_transfer_trb(&self, paddr: u64, id: RingId) -> Option { self.with_ring(id, |ring| ring.phys_addr_to_entry(self.cap.ac64(), paddr)) .flatten() diff --git a/drivers/usb/xhcid/src/xhci/mod.rs b/drivers/usb/xhcid/src/xhci/mod.rs index f2143676..f406d16e 100644 --- a/drivers/usb/xhcid/src/xhci/mod.rs +++ b/drivers/usb/xhcid/src/xhci/mod.rs @@ -12,10 +12,11 @@ use std::collections::BTreeMap; use std::convert::TryFrom; use std::fs::File; -use std::sync::atomic::AtomicUsize; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; use std::{mem, process, slice, thread}; +use std::time::Duration; use syscall::error::{Error, Result, EBADF, EBADMSG, EIO, ENOENT}; use syscall::{EAGAIN, PAGE_SIZE}; @@ -54,7 +55,7 @@ use self::event::EventRing; use self::extended::{CapabilityId, ExtendedCapabilitiesIter, ProtocolSpeed, SupportedProtoCap}; use self::irq_reactor::{EventDoorbell, IrqReactor, NewPendingTrb, RingId}; use self::operational::*; -use self::port::Port; +use self::port::{Port, PortLinkState}; use self::ring::Ring; use self::runtime::RuntimeRegs; use self::trb::{TransferKind, Trb, TrbCompletionCode}; @@ -77,7 +78,174 @@ pub enum InterruptMethod { Msi, } +const ATTACH_STEP_COUNT: usize = 6; +const DETACH_TIMEOUT: Duration = Duration::from_millis(100); + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum PortAttachmentState { + Pending, + Attached, + Detaching, +} + +struct PortRuntime { + detaching: AtomicBool, + inflight_transfers: AtomicUsize, +} + +impl PortRuntime { + fn new() -> Arc { + Arc::new(Self { + detaching: AtomicBool::new(false), + inflight_transfers: AtomicUsize::new(0), + }) + } + + fn begin_transfer(self: &Arc) -> Result { + if self.detaching.load(Ordering::SeqCst) { + return Err(Error::new(EAGAIN)); + } + + self.inflight_transfers.fetch_add(1, Ordering::SeqCst); + + if self.detaching.load(Ordering::SeqCst) { + self.inflight_transfers.fetch_sub(1, Ordering::SeqCst); + return Err(Error::new(EAGAIN)); + } + + Ok(InflightTransferGuard { + runtime: Arc::clone(self), + }) + } + + fn begin_detach(&self) { + self.detaching.store(true, Ordering::SeqCst); + } + + fn inflight_transfers(&self) -> usize { + self.inflight_transfers.load(Ordering::SeqCst) + } +} + +struct InflightTransferGuard { + runtime: Arc, +} + +impl Drop for InflightTransferGuard { + fn drop(&mut self) { + self.runtime.inflight_transfers.fetch_sub(1, Ordering::SeqCst); + } +} + impl Xhci { + fn any_port_state(&self, port: PortId) -> Result>> { + self.port_states + .get(&port) + .or_else(|| self.pending_port_states.get(&port)) + .ok_or(Error::new(ENOENT)) + } + + fn any_port_state_mut( + &self, + port: PortId, + ) -> Result>> { + self.port_states + .get_mut(&port) + .or_else(|| self.pending_port_states.get_mut(&port)) + .ok_or(Error::new(ENOENT)) + } + + pub(crate) fn begin_transfer_guard(&self, port: PortId) -> Result { + let port_state = self.port_states.get(&port).ok_or(Error::new(ENOENT))?; + port_state.runtime.begin_transfer() + } + + fn begin_internal_transfer_guard(&self, port: PortId) -> Result { + let port_state = self.any_port_state(port)?; + port_state.runtime.begin_transfer() + } + + fn wait_for_transfer_drain(&self, runtime: &PortRuntime, timeout: Duration) -> bool { + let timeout = Timeout::new(timeout); + loop { + if runtime.inflight_transfers() == 0 { + return true; + } + + if timeout.run().is_err() { + return false; + } + } + } + + fn attach_log_step(slot: u8, step: usize, step_name: &str) { + info!( + "xhcid: attach step {}/{} for slot {}: {}", + step, ATTACH_STEP_COUNT, slot, step_name + ); + } + + async fn rollback_attach(&self, port_id: PortId, slot: u8, step: usize) -> Result<()> { + warn!( + "xhcid: attach failed at step {}, rolling back slot {}", + step, slot + ); + self.pending_port_states.remove(&port_id); + + match self.disable_port_slot(slot).await { + Ok(()) => Ok(()), + Err(err) => { + warn!( + "xhcid: failed to disable slot {} during attach rollback on port {}: {}", + slot, port_id, err + ); + Err(err) + } + } + } + + fn current_link_state(&self, port_id: PortId) -> Option { + PortLinkState::from_port_state(self.get_pls(port_id)) + } + + fn transition_link_state(&self, port_id: PortId, slot: u8, to: PortLinkState) { + let from = self.current_link_state(port_id); + + if from == Some(to) { + return; + } + + if let Some(from) = from { + info!( + "xhcid: PM slot {}: {}→{}", + slot, + from.as_str(), + to.as_str() + ); + } + + if let Ok(mut ports) = self.ports.lock() { + if let Some(port) = ports.get_mut(port_id.root_hub_port_index()) { + port.set_link_state(to); + } + } + } + + fn wake_port_to_u0(&self, port_id: PortId, slot: u8) { + if self.current_link_state(port_id) == Some(PortLinkState::U3) { + self.transition_link_state(port_id, slot, PortLinkState::U0); + } + } + + fn quiesce_port_to_u3(&self, port_id: PortId, slot: u8) { + if self.current_link_state(port_id) == Some(PortLinkState::U0) { + self.transition_link_state(port_id, slot, PortLinkState::U2); + } + if self.current_link_state(port_id) == Some(PortLinkState::U2) { + self.transition_link_state(port_id, slot, PortLinkState::U3); + } + } + /// Gets descriptors, before the port state is initiated. async fn get_desc_raw( &self, @@ -103,8 +271,11 @@ impl Xhci { len ); + self.wake_port_to_u0(port, slot); + let _transfer_guard = self.begin_internal_transfer_guard(port)?; + let future = { - let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(ENOENT))?; + let mut port_state = self.any_port_state_mut(port)?; let ring = port_state .endpoint_states .get_mut(&0) @@ -283,6 +454,7 @@ pub struct Xhci { handles: CHashMap, next_handle: AtomicUsize, port_states: CHashMap>, + pending_port_states: CHashMap>, drivers: CHashMap>, scheme_name: String, @@ -305,9 +477,11 @@ unsafe impl Send for Xhci {} unsafe impl Sync for Xhci {} struct PortState { + attachment_state: PortAttachmentState, slot: u8, protocol_speed: &'static ProtocolSpeed, cfg_idx: Option, + runtime: Arc, input_context: Mutex>>, dev_desc: Option, endpoint_states: BTreeMap, @@ -463,6 +637,7 @@ impl Xhci { handles: CHashMap::new(), next_handle: AtomicUsize::new(0), port_states: CHashMap::new(), + pending_port_states: CHashMap::new(), drivers: CHashMap::new(), scheme_name, @@ -793,7 +968,8 @@ impl Xhci { } pub async fn attach_device(&self, port_id: PortId) -> syscall::Result<()> { - if self.port_states.contains_key(&port_id) { + if self.port_states.contains_key(&port_id) || self.pending_port_states.contains_key(&port_id) + { debug!("Already contains port {}", port_id); return Err(syscall::Error::new(EAGAIN)); } @@ -822,10 +998,12 @@ impl Xhci { let slot = match self.enable_port_slot(slot_ty).await { Ok(ok) => ok, Err(err) => { + warn!("xhcid: attach failed at step 1, rolling back slot 0"); error!("Failed to enable slot for port {}: {}", port_id, err); return Err(err); } }; + Self::attach_log_step(slot, 1, "enable_slot"); debug!("Enabled port {}, which the xHC mapped to {}", port_id, slot); @@ -836,6 +1014,7 @@ impl Xhci { let mut input = unsafe { self.alloc_dma_zeroed::>()? }; + Self::attach_log_step(slot, 2, "address_device"); debug!("Attempting to address the device"); let mut ring = match self .address_device(&mut input, port_id, slot_ty, slot, protocol_speed, speed) @@ -844,6 +1023,7 @@ impl Xhci { Ok(device_ring) => device_ring, Err(err) => { error!("Failed to address device for port {}: `{}`", port_id, err); + let _ = self.rollback_attach(port_id, slot, 2).await; return Err(err); } }; @@ -853,11 +1033,13 @@ impl Xhci { // TODO: Should the descriptors be cached in PortState, or refetched? let mut port_state = PortState { + attachment_state: PortAttachmentState::Pending, slot, protocol_speed, input_context: Mutex::new(input), dev_desc: None, cfg_idx: None, + runtime: PortRuntime::new(), endpoint_states: std::iter::once(( 0, EndpointState { @@ -867,42 +1049,81 @@ impl Xhci { )) .collect::>(), }; - self.port_states.insert(port_id, port_state); - debug!("Got port states!"); + self.pending_port_states.insert(port_id, port_state); + debug!("Got pending port states!"); // Ensure correct packet size is used - let dev_desc_8_byte = self.fetch_dev_desc_8_byte(port_id, slot).await?; + Self::attach_log_step(slot, 3, "fetch_device_descriptor_8"); + let dev_desc_8_byte = match self.fetch_dev_desc_8_byte(port_id, slot).await { + Ok(desc) => desc, + Err(err) => { + let _ = self.rollback_attach(port_id, slot, 3).await; + return Err(err); + } + }; { - let mut port_state = self.port_states.get_mut(&port_id).unwrap(); - - let mut input = port_state.input_context.lock().unwrap(); - - self.update_max_packet_size(&mut *input, slot, dev_desc_8_byte) - .await?; + let mut port_state = self.any_port_state_mut(port_id)?; + + let mut input = port_state + .input_context + .lock() + .map_err(|_| Error::new(EIO))?; + + Self::attach_log_step(slot, 4, "update_max_packet_size"); + if let Err(err) = self.update_max_packet_size(&mut *input, slot, dev_desc_8_byte).await + { + drop(input); + drop(port_state); + let _ = self.rollback_attach(port_id, slot, 4).await; + return Err(err); + } } debug!("Got the 8 byte dev descriptor: {:X?}", dev_desc_8_byte); - let dev_desc = self.get_desc(port_id, slot).await?; + Self::attach_log_step(slot, 5, "fetch_device_descriptor"); + let dev_desc = match self.get_desc(port_id, slot).await { + Ok(desc) => desc, + Err(err) => { + let _ = self.rollback_attach(port_id, slot, 5).await; + return Err(err); + } + }; debug!("Got the full device descriptor!"); - self.port_states.get_mut(&port_id).unwrap().dev_desc = Some(dev_desc); + self.any_port_state_mut(port_id)?.dev_desc = Some(dev_desc); debug!("Got the port states again!"); { - let mut port_state = self.port_states.get_mut(&port_id).unwrap(); + let mut port_state = self.any_port_state_mut(port_id)?; - let mut input = port_state.input_context.lock().unwrap(); + let mut input = port_state + .input_context + .lock() + .map_err(|_| Error::new(EIO))?; debug!("Got the input context!"); - let dev_desc = port_state.dev_desc.as_ref().unwrap(); - - self.update_default_control_pipe(&mut *input, slot, dev_desc) - .await?; + let dev_desc = port_state.dev_desc.as_ref().ok_or(Error::new(EIO))?; + + Self::attach_log_step(slot, 6, "configure_default_control_pipe"); + if let Err(err) = self.update_default_control_pipe(&mut *input, slot, dev_desc).await + { + drop(input); + drop(port_state); + let _ = self.rollback_attach(port_id, slot, 6).await; + return Err(err); + } } debug!("Updated the default control pipe"); + if let Some(mut published_state) = self.pending_port_states.remove(&port_id) { + published_state.attachment_state = PortAttachmentState::Attached; + self.port_states.insert(port_id, published_state); + } + match self.spawn_drivers(port_id) { - Ok(()) => (), + Ok(()) => { + info!("xhcid: uevent add device usb/{}", port_id.root_hub_port_num); + } Err(err) => { error!("Failed to spawn driver for port {}: `{}`", port_id, err) } @@ -915,6 +1136,32 @@ impl Xhci { } pub async fn detach_device(&self, port_id: PortId) -> Result { + if let Some(mut pending_state) = self.pending_port_states.remove(&port_id) { + pending_state.attachment_state = PortAttachmentState::Detaching; + pending_state.runtime.begin_detach(); + let _ = self.rollback_attach(port_id, pending_state.slot, ATTACH_STEP_COUNT).await; + return Ok(true); + } + + let (slot, runtime, endpoints) = match self.port_states.get_mut(&port_id) { + Some(mut state) => { + state.attachment_state = PortAttachmentState::Detaching; + state.runtime.begin_detach(); + ( + state.slot, + Arc::clone(&state.runtime), + state.endpoint_states.keys().copied().collect::>(), + ) + } + None => { + debug!( + "Attempted to detach from port {}, which wasn't previously attached.", + port_id + ); + return Ok(false); + } + }; + if let Some(children) = self.drivers.remove(&port_id) { for mut child in children { info!("killing driver process {} for port {}", child.id(), port_id); @@ -962,21 +1209,38 @@ impl Xhci { } } - if let Some(state) = self.port_states.remove(&port_id) { - debug!("disabling port slot {} for port {}", state.slot, port_id); - let result = self.disable_port_slot(state.slot).await.and(Ok(true)); - debug!( - "disabled port slot {} for port {} with result: {:?}", - state.slot, port_id, result - ); - result - } else { - debug!( - "Attempted to detach from port {}, which wasn't previously attached.", - port_id - ); - Ok(false) + let drained = self.wait_for_transfer_drain(&runtime, DETACH_TIMEOUT); + + self.wake_port_to_u0(port_id, slot); + + let mut timed_out = !drained; + for endp_num in endpoints { + if timed_out { + break; + } + + if !self.stop_endpoint_with_timeout(port_id, slot, endp_num, DETACH_TIMEOUT).await? { + timed_out = true; + break; + } } + + self.quiesce_port_to_u3(port_id, slot); + + if timed_out { + warn!("xhcid: forced detach slot {} after timeout", slot); + } + + debug!("disabling port slot {} for port {}", slot, port_id); + let result = self.disable_port_slot(slot).await.and(Ok(true)); + debug!( + "disabled port slot {} for port {} with result: {:?}", + slot, port_id, result + ); + + self.port_states.remove(&port_id); + info!("xhcid: uevent remove device usb/{}", port_id.root_hub_port_num); + result } pub async fn update_max_packet_size( diff --git a/drivers/usb/xhcid/src/xhci/port.rs b/drivers/usb/xhcid/src/xhci/port.rs index 0654ccc3..5edbd9cb 100644 --- a/drivers/usb/xhcid/src/xhci/port.rs +++ b/drivers/usb/xhcid/src/xhci/port.rs @@ -46,6 +46,37 @@ bitflags! { } } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum PortLinkState { + U0 = 0, + U1 = 1, + U2 = 2, + U3 = 3, +} + +impl PortLinkState { + const PORT_STATE_MASK: u32 = 0b1111 << 5; + + pub fn from_port_state(state: u8) -> Option { + match state { + 0 => Some(Self::U0), + 1 => Some(Self::U1), + 2 => Some(Self::U2), + 3 => Some(Self::U3), + _ => None, + } + } + + pub fn as_str(self) -> &'static str { + match self { + Self::U0 => "U0", + Self::U1 => "U1", + Self::U2 => "U2", + Self::U3 => "U3", + } + } +} + #[repr(C, packed)] pub struct Port { // This has write one to clear fields, do not expose it, handle writes carefully! @@ -75,6 +106,14 @@ impl Port { .write((self.flags_preserved() | PortFlags::PR).bits()); } + pub fn set_link_state(&mut self, state: PortLinkState) { + let mut value = self.flags_preserved().bits(); + value &= !PortLinkState::PORT_STATE_MASK; + value |= (state as u32) << 5; + value |= PortFlags::LWS.bits(); + self.portsc.write(value); + } + pub fn state(&self) -> u8 { ((self.read() & (0b1111 << 5)) >> 5) as u8 } diff --git a/drivers/usb/xhcid/src/xhci/scheme.rs b/drivers/usb/xhcid/src/xhci/scheme.rs index ca27b3fe..2c27b906 100644 --- a/drivers/usb/xhcid/src/xhci/scheme.rs +++ b/drivers/usb/xhcid/src/xhci/scheme.rs @@ -20,7 +20,9 @@ use std::convert::TryFrom; use std::io::prelude::*; use std::ops::Deref; use std::sync::atomic; -use std::{cmp, fmt, io, mem, str}; +use std::collections::BTreeMap; +use std::time::Duration; +use std::{cmp, fmt, io, mem, ptr, str}; use common::dma::Dma; use futures::executor::block_on; @@ -557,6 +559,31 @@ impl AnyDescriptor { } impl Xhci { + fn snapshot_input_context(input_context: &Dma>) -> Box<[u8]> { + let mut snapshot = vec![0u8; mem::size_of::>()].into_boxed_slice(); + unsafe { + ptr::copy_nonoverlapping( + (&**input_context) as *const super::InputContext as *const u8, + snapshot.as_mut_ptr(), + snapshot.len(), + ); + } + snapshot + } + + fn restore_input_context( + input_context: &mut Dma>, + snapshot: &[u8], + ) { + unsafe { + ptr::copy_nonoverlapping( + snapshot.as_ptr(), + (&mut **input_context) as *mut super::InputContext as *mut u8, + snapshot.len(), + ); + } + } + async fn new_if_desc( &self, port_id: PortId, @@ -629,6 +656,37 @@ impl Xhci { (event_trb, command_trb) } + pub fn execute_command_with_timeout( + &self, + timeout: Duration, + f: F, + ) -> Option<(Trb, Trb)> { + if self.interrupt_is_pending(0) { + debug!("The EHB bit is already set!"); + } + + let pending_wait = { + let Ok(mut command_ring) = self.cmd.lock() else { + return None; + }; + let (cmd_index, cycle) = (command_ring.next_index(), command_ring.cycle); + + { + let command_trb = &mut command_ring.trbs[cmd_index]; + f(command_trb, cycle); + } + + let command_trb = &command_ring.trbs[cmd_index]; + self.queue_command_completion_wait( + &*command_ring, + command_trb, + EventDoorbell::new(self, 0, 0), + ) + }; + + let trbs = pending_wait.wait_timeout(timeout)?; + Some((trbs.event_trb, trbs.src_trb?)) + } pub async fn execute_control_transfer( &self, port_num: PortId, @@ -640,6 +698,9 @@ impl Xhci { where D: FnMut(&mut Trb, bool) -> ControlFlow, { + self.wake_port_to_u0(port_num, self.slot(port_num)?); + let _transfer_guard = self.begin_transfer_guard(port_num)?; + let future = { let mut port_state = self.port_state_mut(port_num)?; let slot = port_state.slot; @@ -710,6 +771,9 @@ impl Xhci { where D: FnMut(&mut Trb, bool) -> ControlFlow, { + self.wake_port_to_u0(port_num, self.slot(port_num)?); + let _transfer_guard = self.begin_transfer_guard(port_num)?; + let endp_idx = endp_num.checked_sub(1).ok_or(Error::new(EIO))?; let mut port_state = self.port_state_mut(port_num)?; @@ -863,6 +927,34 @@ impl Xhci { handle_event_trb("RESET_ENDPOINT", &event_trb, &command_trb) } + pub async fn stop_endpoint_with_timeout( + &self, + port_num: PortId, + slot: u8, + endp_num: u8, + timeout: Duration, + ) -> Result { + let endp_num_xhc = if endp_num == 0 { + 1 + } else { + let endp_idx = endp_num.checked_sub(1).ok_or(Error::new(EIO))?; + let port_state = self.port_states.get(&port_num).ok_or(Error::new(EBADFD))?; + let endp_desc = port_state + .get_endp_desc(endp_idx) + .ok_or(Error::new(EBADFD))?; + Self::endp_num_to_dci(endp_num, endp_desc) + }; + + let Some((event_trb, command_trb)) = self.execute_command_with_timeout(timeout, |trb, cycle| { + trb.stop_endpoint(slot, endp_num_xhc, false, cycle); + }) else { + return Ok(false); + }; + + handle_event_trb("STOP_ENDPOINT", &event_trb, &command_trb)?; + Ok(true) + } + fn endp_ctx_interval(speed_id: &ProtocolSpeed, endp_desc: &EndpDesc) -> u8 { /// Logarithmic (base 2) 125 µs periods per millisecond. const MILLISEC_PERIODS: u8 = 3; @@ -956,9 +1048,7 @@ impl Xhci { req: &ConfigureEndpointsReq, ) -> Result<()> { let (endp_desc_count, new_context_entries, configuration_value) = { - let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; - - port_state.cfg_idx = Some(req.config_desc); + let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; let config_desc = port_state .dev_desc @@ -1003,210 +1093,259 @@ impl Xhci { Error::new(EIO) })?; - { + let (slot, previous_cfg_idx, input_snapshot) = { let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; - let mut input_context = port_state.input_context.lock().unwrap(); - - // Configure the slot context as well, which holds the last index of the endp descs. - input_context.add_context.write(1); - input_context.drop_context.write(0); - - const CONTEXT_ENTRIES_MASK: u32 = 0xF800_0000; - const CONTEXT_ENTRIES_SHIFT: u8 = 27; + let input_snapshot = { + let input_context = port_state + .input_context + .lock() + .map_err(|_| Error::new(EIO))?; + Self::snapshot_input_context(&*input_context) + }; - const HUB_PORTS_MASK: u32 = 0xFF00_0000; - const HUB_PORTS_SHIFT: u8 = 24; + (port_state.slot, port_state.cfg_idx, input_snapshot) + }; - let mut current_slot_a = input_context.device.slot.a.read(); - let mut current_slot_b = input_context.device.slot.b.read(); + let mut staged_endpoint_states = BTreeMap::new(); + let stage_result = (|| -> Result<()> { + let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; - // Set context entries - current_slot_a &= !CONTEXT_ENTRIES_MASK; - current_slot_a |= - (u32::from(new_context_entries) << CONTEXT_ENTRIES_SHIFT) & CONTEXT_ENTRIES_MASK; + { + let mut input_context = port_state + .input_context + .lock() + .map_err(|_| Error::new(EIO))?; - // Set hub data - current_slot_a &= !(1 << 26); - current_slot_b &= !HUB_PORTS_MASK; - if let Some(hub_ports) = req.hub_ports { - current_slot_a |= 1 << 26; - current_slot_b |= (u32::from(hub_ports) << HUB_PORTS_SHIFT) & HUB_PORTS_MASK; - } + // Configure the slot context as well, which holds the last index of the endp descs. + input_context.add_context.write(1); + input_context.drop_context.write(0); - input_context.device.slot.a.write(current_slot_a); - input_context.device.slot.b.write(current_slot_b); + const CONTEXT_ENTRIES_MASK: u32 = 0xF800_0000; + const CONTEXT_ENTRIES_SHIFT: u8 = 27; - let control = if self.op.lock().unwrap().cie() { - (u32::from(req.alternate_setting.unwrap_or(0)) << 16) - | (u32::from(req.interface_desc.unwrap_or(0)) << 8) - | u32::from(configuration_value) - } else { - 0 - }; - input_context.control.write(control); - } + const HUB_PORTS_MASK: u32 = 0xFF00_0000; + const HUB_PORTS_SHIFT: u8 = 24; - for endp_idx in 0..endp_desc_count as u8 { - let endp_num = endp_idx + 1; + let mut current_slot_a = input_context.device.slot.a.read(); + let mut current_slot_b = input_context.device.slot.b.read(); - let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; - let dev_desc = port_state.dev_desc.as_ref().unwrap(); - let endp_desc = port_state.get_endp_desc(endp_idx).ok_or_else(|| { - warn!("failed to find endpoint {}", endp_idx); - Error::new(EIO) - })?; + current_slot_a &= !CONTEXT_ENTRIES_MASK; + current_slot_a |= (u32::from(new_context_entries) << CONTEXT_ENTRIES_SHIFT) + & CONTEXT_ENTRIES_MASK; - let endp_num_xhc = Self::endp_num_to_dci(endp_num, endp_desc); + current_slot_a &= !(1 << 26); + current_slot_b &= !HUB_PORTS_MASK; + if let Some(hub_ports) = req.hub_ports { + current_slot_a |= 1 << 26; + current_slot_b |= (u32::from(hub_ports) << HUB_PORTS_SHIFT) & HUB_PORTS_MASK; + } - let usb_log_max_streams = endp_desc.log_max_streams(); + input_context.device.slot.a.write(current_slot_a); + input_context.device.slot.b.write(current_slot_b); - // TODO: Secondary streams. - let primary_streams = if let Some(log_max_streams) = usb_log_max_streams { - // TODO: Can streams-capable be configured to not use streams? - if log_max_psa_size != 0 { - cmp::min(u8::from(log_max_streams), log_max_psa_size + 1) - 1 + let control = if self.op.lock().map_err(|_| Error::new(EIO))?.cie() { + (u32::from(req.alternate_setting.unwrap_or(0)) << 16) + | (u32::from(req.interface_desc.unwrap_or(0)) << 8) + | u32::from(configuration_value) } else { 0 - } - } else { - 0 - }; - let linear_stream_array = if primary_streams != 0 { true } else { false }; + }; + input_context.control.write(control); + } - // TODO: Interval related fields - // TODO: Max ESIT payload size. + for endp_idx in 0..endp_desc_count as u8 { + let endp_num = endp_idx + 1; - let mult = endp_desc.isoch_mult(lec); + let dev_desc = port_state.dev_desc.as_ref().ok_or(Error::new(EBADFD))?; + let endp_desc = port_state.get_endp_desc(endp_idx).ok_or_else(|| { + warn!("failed to find endpoint {}", endp_idx); + Error::new(EIO) + })?; - let max_packet_size = Self::endp_ctx_max_packet_size(endp_desc); - let max_burst_size = Self::endp_ctx_max_burst(speed_id, dev_desc, endp_desc); + let endp_num_xhc = Self::endp_num_to_dci(endp_num, endp_desc); - let max_esit_payload = Self::endp_ctx_max_esit_payload( - speed_id, - dev_desc, - endp_desc, - max_packet_size, - max_burst_size, - ); - let max_esit_payload_lo = max_esit_payload as u16; - let max_esit_payload_hi = ((max_esit_payload & 0x00FF_0000) >> 16) as u8; - - let interval = Self::endp_ctx_interval(speed_id, endp_desc); - - let max_error_count = 3; - let ep_ty = endp_desc.xhci_ep_type()?; - let host_initiate_disable = false; - - // TODO: Maybe this value is out of scope for xhcid, because the actual usb device - // driver probably knows better. The spec says that the initial value should be 8 bytes - // for control, 1KiB for interrupt and 3KiB for bulk and isoch. - let avg_trb_len: u16 = match endp_desc.ty() { - EndpointTy::Ctrl => { - warn!("trying to use control endpoint"); - return Err(Error::new(EIO)); // only endpoint zero is of type control, and is configured separately with the address device command. - } - EndpointTy::Bulk | EndpointTy::Isoch => 3072, // 3 KiB - EndpointTy::Interrupt => 1024, // 1 KiB - }; + let usb_log_max_streams = endp_desc.log_max_streams(); + let primary_streams = if let Some(log_max_streams) = usb_log_max_streams { + if log_max_psa_size != 0 { + cmp::min(u8::from(log_max_streams), log_max_psa_size + 1) - 1 + } else { + 0 + } + } else { + 0 + }; + let linear_stream_array = primary_streams != 0; + + let mult = endp_desc.isoch_mult(lec); + let max_packet_size = Self::endp_ctx_max_packet_size(endp_desc); + let max_burst_size = Self::endp_ctx_max_burst(speed_id, dev_desc, endp_desc); + let max_esit_payload = Self::endp_ctx_max_esit_payload( + speed_id, + dev_desc, + endp_desc, + max_packet_size, + max_burst_size, + ); + let max_esit_payload_lo = max_esit_payload as u16; + let max_esit_payload_hi = ((max_esit_payload & 0x00FF_0000) >> 16) as u8; + let interval = Self::endp_ctx_interval(speed_id, endp_desc); + + let max_error_count = 3; + let ep_ty = endp_desc.xhci_ep_type()?; + let host_initiate_disable = false; + let avg_trb_len: u16 = match endp_desc.ty() { + EndpointTy::Ctrl => { + warn!("trying to use control endpoint"); + return Err(Error::new(EIO)); + } + EndpointTy::Bulk | EndpointTy::Isoch => 3072, + EndpointTy::Interrupt => 1024, + }; - assert_eq!(ep_ty & 0x7, ep_ty); - assert_eq!(mult & 0x3, mult); - assert_eq!(max_error_count & 0x3, max_error_count); - assert_ne!(ep_ty, 0); // 0 means invalid. + assert_eq!(ep_ty & 0x7, ep_ty); + assert_eq!(mult & 0x3, mult); + assert_eq!(max_error_count & 0x3, max_error_count); + assert_ne!(ep_ty, 0); + + let (ring_ptr, staged_state) = if usb_log_max_streams.is_some() { + let mut array = + StreamContextArray::new::(self.cap.ac64(), 1 << (primary_streams + 1))?; + array.add_ring::(self.cap.ac64(), 1, true)?; + let array_ptr = array.register(); + + assert_eq!( + array_ptr & 0xFFFF_FFFF_FFFF_FF81, + array_ptr, + "stream ctx ptr not aligned to 16 bytes" + ); - let ring_ptr = if usb_log_max_streams.is_some() { - let mut array = - StreamContextArray::new::(self.cap.ac64(), 1 << (primary_streams + 1))?; + ( + array_ptr, + EndpointState { + transfer: super::RingOrStreams::Streams(array), + driver_if_state: EndpIfState::Init, + }, + ) + } else { + let ring = Ring::new::(self.cap.ac64(), 16, true)?; + let ring_ptr = ring.register(); - // TODO: Use as many stream rings as needed. - array.add_ring::(self.cap.ac64(), 1, true)?; - let array_ptr = array.register(); + assert_eq!( + ring_ptr & 0xFFFF_FFFF_FFFF_FF81, + ring_ptr, + "ring pointer not aligned to 16 bytes" + ); - assert_eq!( - array_ptr & 0xFFFF_FFFF_FFFF_FF81, - array_ptr, - "stream ctx ptr not aligned to 16 bytes" + ( + ring_ptr, + EndpointState { + transfer: super::RingOrStreams::Ring(ring), + driver_if_state: EndpIfState::Init, + }, + ) + }; + assert_eq!(primary_streams & 0x1F, primary_streams); + + staged_endpoint_states.insert(endp_num, staged_state); + + let mut input_context = port_state + .input_context + .lock() + .map_err(|_| Error::new(EIO))?; + input_context.add_context.writef(1 << endp_num_xhc, true); + + let endp_i = endp_num_xhc as usize - 1; + input_context.device.endpoints[endp_i].a.write( + u32::from(mult) << 8 + | u32::from(primary_streams) << 10 + | u32::from(linear_stream_array) << 15 + | u32::from(interval) << 16 + | u32::from(max_esit_payload_hi) << 24, ); - port_state.endpoint_states.insert( - endp_num, - EndpointState { - transfer: super::RingOrStreams::Streams(array), - driver_if_state: EndpIfState::Init, - }, + input_context.device.endpoints[endp_i].b.write( + max_error_count << 1 + | u32::from(ep_ty) << 3 + | u32::from(host_initiate_disable) << 7 + | u32::from(max_burst_size) << 8 + | u32::from(max_packet_size) << 16, ); - array_ptr - } else { - let ring = Ring::new::(self.cap.ac64(), 16, true)?; - let ring_ptr = ring.register(); - - assert_eq!( - ring_ptr & 0xFFFF_FFFF_FFFF_FF81, - ring_ptr, - "ring pointer not aligned to 16 bytes" - ); - port_state.endpoint_states.insert( - endp_num, - EndpointState { - transfer: super::RingOrStreams::Ring(ring), - driver_if_state: EndpIfState::Init, - }, - ); - ring_ptr - }; - assert_eq!(primary_streams & 0x1F, primary_streams); - - let mut input_context = port_state.input_context.lock().unwrap(); - input_context.add_context.writef(1 << endp_num_xhc, true); - - let endp_i = endp_num_xhc as usize - 1; - input_context.device.endpoints[endp_i].a.write( - u32::from(mult) << 8 - | u32::from(primary_streams) << 10 - | u32::from(linear_stream_array) << 15 - | u32::from(interval) << 16 - | u32::from(max_esit_payload_hi) << 24, - ); - input_context.device.endpoints[endp_i].b.write( - max_error_count << 1 - | u32::from(ep_ty) << 3 - | u32::from(host_initiate_disable) << 7 - | u32::from(max_burst_size) << 8 - | u32::from(max_packet_size) << 16, - ); + input_context.device.endpoints[endp_i].trl.write(ring_ptr as u32); + input_context.device.endpoints[endp_i].trh.write((ring_ptr >> 32) as u32); + input_context.device.endpoints[endp_i] + .c + .write(u32::from(avg_trb_len) | (u32::from(max_esit_payload_lo) << 16)); - input_context.device.endpoints[endp_i] - .trl - .write(ring_ptr as u32); - input_context.device.endpoints[endp_i] - .trh - .write((ring_ptr >> 32) as u32); + log::debug!("initialized endpoint {}", endp_num); + } - input_context.device.endpoints[endp_i] - .c - .write(u32::from(avg_trb_len) | (u32::from(max_esit_payload_lo) << 16)); + Ok(()) + })(); - log::debug!("initialized endpoint {}", endp_num); + if let Err(err) = stage_result { + warn!("xhcid: configure slot {} failed, rolling back", slot); + let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; + let mut input_context = port_state + .input_context + .lock() + .map_err(|_| Error::new(EIO))?; + Self::restore_input_context(&mut *input_context, &input_snapshot); + return Err(err); } - { - let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; - let slot = port_state.slot; - let input_context_physical = port_state.input_context.lock().unwrap().physical(); + let input_context_physical = self + .port_states + .get(&port) + .ok_or(Error::new(EBADFD))? + .input_context + .lock() + .map_err(|_| Error::new(EIO))? + .physical(); + + let (event_trb, command_trb) = self + .execute_command(|trb, cycle| trb.configure_endpoint(slot, input_context_physical, cycle)) + .await; + + if let Err(err) = handle_event_trb("CONFIGURE_ENDPOINT", &event_trb, &command_trb) { + warn!("xhcid: configure slot {} failed, rolling back", slot); + let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; + let mut input_context = port_state + .input_context + .lock() + .map_err(|_| Error::new(EIO))?; + Self::restore_input_context(&mut *input_context, &input_snapshot); + return Err(err); + } - let (event_trb, command_trb) = self - .execute_command(|trb, cycle| { - trb.configure_endpoint(slot, input_context_physical, cycle) - }) - .await; + if let Err(err) = self.set_configuration(port, configuration_value).await { + warn!("xhcid: configure slot {} failed, rolling back", slot); + { + let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; + let mut input_context = port_state + .input_context + .lock() + .map_err(|_| Error::new(EIO))?; + Self::restore_input_context(&mut *input_context, &input_snapshot); + } - //self.event_handler_finished(); + if let Err(restore_err) = self.set_configuration(port, previous_cfg_idx.unwrap_or(0)).await { + warn!( + "xhcid: failed to restore configuration {} for slot {}: {}", + previous_cfg_idx.unwrap_or(0), + slot, + restore_err + ); + } - handle_event_trb("CONFIGURE_ENDPOINT", &event_trb, &command_trb)?; + return Err(err); } - // Tell the device about this configuration. - self.set_configuration(port, configuration_value).await?; + let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; + port_state.cfg_idx = Some(req.config_desc); + port_state.endpoint_states.retain(|endp_num, _| *endp_num == 0); + for (endp_num, state) in staged_endpoint_states { + port_state.endpoint_states.insert(endp_num, state); + } Ok(()) }