# P2-xhcid-remaining.patch # Extract xhcid remaining hardening: MSI-X/MSI/legacy IRQ fallback, test hooks, # port lifecycle management, staged port states, suspend/resume, endpoint # configuration rollback, and power management. # # Files: drivers/usb/xhcid/src/main.rs, drivers/usb/xhcid/src/xhci/device_enumerator.rs, # drivers/usb/xhcid/src/xhci/mod.rs, drivers/usb/xhcid/src/xhci/scheme.rs diff --git a/drivers/usb/xhcid/src/main.rs b/drivers/usb/xhcid/src/main.rs index d345a52f..562c580a 100644 --- a/drivers/usb/xhcid/src/main.rs +++ b/drivers/usb/xhcid/src/main.rs @@ -33,7 +33,7 @@ use std::sync::Arc; use pcid_interface::irq_helpers::read_bsp_apic_id; #[cfg(target_arch = "x86_64")] use pcid_interface::irq_helpers::{ - allocate_first_msi_interrupt_on_bsp, allocate_single_interrupt_vector_for_msi, + try_allocate_first_msi_interrupt_on_bsp, try_allocate_single_interrupt_vector_for_msi, }; use pcid_interface::{PciFeature, PciFeatureInfo, PciFunctionHandle}; @@ -61,11 +61,24 @@ fn get_int_method(pcid_handle: &mut PciFunctionHandle) -> (Option, Interru let has_msix = all_pci_features.iter().any(|feature| feature.is_msix()); if has_msix { - let msix_info = match pcid_handle.feature_info(PciFeature::MsiX) { - PciFeatureInfo::Msi(_) => panic!(), - PciFeatureInfo::MsiX(s) => s, + let msix_info = match pcid_handle.try_feature_info(PciFeature::MsiX) { + Ok(PciFeatureInfo::MsiX(s)) => s, + Ok(PciFeatureInfo::Msi(_)) => { + log::error!("xhcid: invalid MSI-X feature response payload"); + return (None, InterruptMethod::Polling); + } + Err(err) => { + log::error!("xhcid: failed to fetch MSI-X feature info: {err}"); + return (None, InterruptMethod::Polling); + } + }; + let mut info = match unsafe { msix_info.try_map_and_mask_all(pcid_handle) } { + Ok(info) => info, + Err(err) => { + log::error!("xhcid: failed to map MSI-X registers: {err}"); + return (None, InterruptMethod::Polling); + } }; - let mut info = unsafe { msix_info.map_and_mask_all(pcid_handle) }; // Allocate one msi vector. @@ -75,27 +88,53 @@ fn get_int_method(pcid_handle: &mut PciFunctionHandle) -> (Option, Interru let table_entry_pointer = info.table_entry_pointer(k); - let destination_id = read_bsp_apic_id().expect("xhcid: failed to read BSP apic id"); + let destination_id = match read_bsp_apic_id() { + Ok(id) => id, + Err(err) => { + log::error!("xhcid: failed to read BSP APIC ID: {err}"); + return (None, InterruptMethod::Polling); + } + }; let (msg_addr_and_data, interrupt_handle) = - allocate_single_interrupt_vector_for_msi(destination_id); + match try_allocate_single_interrupt_vector_for_msi(destination_id) { + Ok(result) => result, + Err(err) => { + log::error!("xhcid: failed to allocate MSI-X vector: {err}"); + return (None, InterruptMethod::Polling); + } + }; table_entry_pointer.write_addr_and_data(msg_addr_and_data); table_entry_pointer.unmask(); (Some(interrupt_handle), InterruptMethod::Msi) }; - pcid_handle.enable_feature(PciFeature::MsiX); + if let Err(err) = pcid_handle.try_enable_feature(PciFeature::MsiX) { + log::error!("xhcid: failed to enable MSI-X: {err}"); + return (None, InterruptMethod::Polling); + } log::debug!("Enabled MSI-X"); method } else if has_msi { - let interrupt_handle = allocate_first_msi_interrupt_on_bsp(pcid_handle); - (Some(interrupt_handle), InterruptMethod::Msi) + match try_allocate_first_msi_interrupt_on_bsp(pcid_handle) { + Ok(interrupt_handle) => (Some(interrupt_handle), InterruptMethod::Msi), + Err(err) => { + log::error!("xhcid: failed to allocate MSI interrupt: {err}"); + (None, InterruptMethod::Polling) + } + } } else if let Some(irq) = pci_config.func.legacy_interrupt_line { log::debug!("Legacy IRQ {}", irq); // legacy INTx# interrupt pins. - (Some(irq.irq_handle("xhcid")), InterruptMethod::Intx) + match irq.try_irq_handle("xhcid") { + Ok(file) => (Some(file), InterruptMethod::Intx), + Err(err) => { + log::error!("xhcid: failed to open legacy IRQ handle: {err}"); + (None, InterruptMethod::Polling) + } + } } else { // no interrupts at all (None, InterruptMethod::Polling) @@ -109,7 +148,13 @@ fn get_int_method(pcid_handle: &mut PciFunctionHandle) -> (Option, Interru if let Some(irq) = pci_config.func.legacy_interrupt_line { // legacy INTx# interrupt pins. - (Some(irq.irq_handle("xhcid")), InterruptMethod::Intx) + match irq.try_irq_handle("xhcid") { + Ok(file) => (Some(file), InterruptMethod::Intx), + Err(err) => { + log::error!("xhcid: failed to open legacy IRQ handle: {err}"); + (None, InterruptMethod::Polling) + } + } } else { // no interrupts at all (None, InterruptMethod::Polling) @@ -136,23 +181,48 @@ fn daemon_with_context_size( log::debug!("XHCI PCI CONFIG: {:?}", pci_config); - let address = unsafe { pcid_handle.map_bar(0) }.ptr.as_ptr() as usize; + let address = match unsafe { pcid_handle.try_map_bar(0) } { + Ok(bar) => bar.ptr.as_ptr() as usize, + Err(err) => { + log::error!("xhcid: failed to map BAR0: {err}"); + std::process::exit(1); + } + }; + + let (irq_file, interrupt_method) = get_int_method(&mut pcid_handle); - let (irq_file, interrupt_method) = (None, InterruptMethod::Polling); //get_int_method(&mut pcid_handle); - //TODO: Fix interrupts. + match interrupt_method { + InterruptMethod::Msi => log::info!("xhcid: using MSI/MSI-X interrupt delivery"), + InterruptMethod::Intx => log::info!("xhcid: using legacy INTx interrupt delivery"), + InterruptMethod::Polling => log::warn!("xhcid: falling back to polling mode"), + } log::info!("XHCI {}", pci_config.func.display()); let scheme_name = format!("usb.{}", name); - let socket = Socket::create().expect("xhcid: failed to create usb scheme"); + let socket = match Socket::create() { + Ok(socket) => socket, + Err(err) => { + log::error!("xhcid: failed to create usb scheme: {err}"); + std::process::exit(1); + } + }; let handler = Blocking::new(&socket, 16); let hci = Arc::new( - Xhci::::new(scheme_name.clone(), address, interrupt_method, pcid_handle) - .expect("xhcid: failed to allocate device"), + match Xhci::::new(scheme_name.clone(), address, interrupt_method, pcid_handle) { + Ok(hci) => hci, + Err(err) => { + log::error!("xhcid: failed to allocate device: {err}"); + std::process::exit(1); + } + }, ); register_sync_scheme(&socket, &scheme_name, &mut &*hci) - .expect("xhcid: failed to regsiter scheme to namespace"); + .unwrap_or_else(|err| { + log::error!("xhcid: failed to register scheme to namespace: {err}"); + std::process::exit(1); + }); daemon.ready(); @@ -163,7 +233,10 @@ fn daemon_with_context_size( handler .process_requests_blocking(&*hci) - .expect("xhcid: failed to process requests"); + .unwrap_or_else(|err| { + log::error!("xhcid: failed to process requests: {err}"); + std::process::exit(1); + }); } fn main() { @@ -171,7 +244,13 @@ fn main() { } fn daemon(daemon: daemon::Daemon, mut pcid_handle: PciFunctionHandle) -> ! { - let address = unsafe { pcid_handle.map_bar(0) }.ptr.as_ptr() as usize; + let address = match unsafe { pcid_handle.try_map_bar(0) } { + Ok(bar) => bar.ptr.as_ptr() as usize, + Err(err) => { + log::error!("xhcid: failed to map BAR0: {err}"); + std::process::exit(1); + } + }; let cap = unsafe { &mut *(address as *mut xhci::CapabilityRegs) }; if cap.csz() { daemon_with_context_size::<{ xhci::CONTEXT_64 }>(daemon, pcid_handle) diff --git a/drivers/usb/xhcid/src/xhci/device_enumerator.rs b/drivers/usb/xhcid/src/xhci/device_enumerator.rs index 74b9f732..493e79df 100644 --- a/drivers/usb/xhcid/src/xhci/device_enumerator.rs +++ b/drivers/usb/xhcid/src/xhci/device_enumerator.rs @@ -4,9 +4,11 @@ use common::io::Io; use crossbeam_channel; use log::{debug, info, warn}; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use syscall::EAGAIN; +const DEFAULT_PORT_RESET_SETTLE_MS: u64 = 16; + pub struct DeviceEnumerationRequest { pub port_id: PortId, } @@ -28,7 +30,11 @@ impl DeviceEnumerator { let request = match self.request_queue.recv() { Ok(req) => req, Err(err) => { - panic!("Failed to received an enumeration request! error: {}", err) + warn!( + "device enumerator stopping after request queue closed: {}", + err + ); + break; } }; @@ -38,7 +44,11 @@ impl DeviceEnumerator { debug!("Device Enumerator request for port {}", port_id); let (len, flags) = { - let ports = self.hci.ports.lock().unwrap(); + let ports = self + .hci + .ports + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); let len = ports.len(); @@ -62,43 +72,52 @@ impl DeviceEnumerator { //A USB3 port won't generate a Connect Status Change until it's already enabled, so this check //will always be skipped for USB3 ports if !flags.contains(PortFlags::PED) { - let disabled_state = flags.contains(PortFlags::PP) - && flags.contains(PortFlags::CCS) - && !flags.contains(PortFlags::PED) - && !flags.contains(PortFlags::PR); + let disabled_state = Self::port_is_disabled(&flags); if !disabled_state { - panic!( - "Port {} isn't in the disabled state! Current flags: {:?}", + warn!( + "Port {} never reached the disabled state before reset-driven enumeration; current flags: {:?}", port_id, flags ); + continue; } else { debug!("Port {} has entered the disabled state.", port_id); } //THIS LOCKS THE PORTS. DO NOT LOCK PORTS BEFORE THIS POINT debug!("Received a device connect on port {}, but it's not enabled. Resetting the port.", port_id); - let _ = self.hci.reset_port(port_id); + if let Err(err) = self.hci.reset_port(port_id) { + warn!( + "failed to reset port {} before enumeration; skipping attach: {}", + port_id, err + ); + continue; + } - let mut ports = self.hci.ports.lock().unwrap(); + let mut ports = self + .hci + .ports + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); let port = &mut ports[port_array_index]; port.clear_prc(); - std::thread::sleep(Duration::from_millis(16)); //Some controllers need some extra time to make the transition. + drop(ports); - let flags = port.flags(); + let flags = self.wait_for_port_enabled_state( + port_array_index, + Duration::from_millis(DEFAULT_PORT_RESET_SETTLE_MS), + ); - let enabled_state = flags.contains(PortFlags::PP) - && flags.contains(PortFlags::CCS) - && flags.contains(PortFlags::PED) - && !flags.contains(PortFlags::PR); + let enabled_state = Self::port_is_enabled(&flags); if !enabled_state { warn!( - "Port {} isn't in the enabled state! Current flags: {:?}", + "Port {} isn't in the enabled state after bounded reset settle; current flags: {:?}", port_id, flags ); + continue; } else { debug!( "Port {} is in the enabled state. Proceeding with enumeration", @@ -131,13 +150,60 @@ impl DeviceEnumerator { Ok(was_connected) => { if was_connected { info!("Device on port {} was detached", port_id); + } else { + debug!( + "Ignoring duplicate or out-of-order detach event for unattached port {}", + port_id + ); } } Err(err) => { - warn!("processing of device attach request failed! Error: {}", err); + warn!("processing of device detach request failed! Error: {}", err); } } } } } + + fn port_is_disabled(flags: &PortFlags) -> bool { + flags.contains(PortFlags::PP) + && flags.contains(PortFlags::CCS) + && !flags.contains(PortFlags::PED) + && !flags.contains(PortFlags::PR) + } + + fn port_is_enabled(flags: &PortFlags) -> bool { + flags.contains(PortFlags::PP) + && flags.contains(PortFlags::CCS) + && flags.contains(PortFlags::PED) + && !flags.contains(PortFlags::PR) + } + + fn wait_for_port_enabled_state( + &self, + port_array_index: usize, + settle_timeout: Duration, + ) -> PortFlags { + let start = Instant::now(); + + loop { + let flags = { + let ports = self + .hci + .ports + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + ports[port_array_index].flags() + }; + + if Self::port_is_enabled(&flags) + || !flags.contains(PortFlags::PR) + || start.elapsed() >= settle_timeout + { + return flags; + } + + std::thread::sleep(Duration::from_millis(1)); + } + } } diff --git a/drivers/usb/xhcid/src/xhci/mod.rs b/drivers/usb/xhcid/src/xhci/mod.rs index f2143676..0d2ec432 100644 --- a/drivers/usb/xhcid/src/xhci/mod.rs +++ b/drivers/usb/xhcid/src/xhci/mod.rs @@ -11,12 +11,13 @@ //! documents are specified in the crate-level documentation. use std::collections::BTreeMap; use std::convert::TryFrom; -use std::fs::File; +use std::fs::{self, File}; use std::sync::atomic::AtomicUsize; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Condvar, Mutex}; +use std::time::Duration; use std::{mem, process, slice, thread}; -use syscall::error::{Error, Result, EBADF, EBADMSG, EIO, ENOENT}; +use syscall::error::{Error, Result, EBADF, EBADMSG, EBUSY, EIO, ENOENT}; use syscall::{EAGAIN, PAGE_SIZE}; use chashmap::CHashMap; @@ -77,7 +78,52 @@ pub enum InterruptMethod { Msi, } +const XHCID_TEST_HOOK_PATH: &str = "/tmp/xhcid-test-hook"; +const XHCID_TEST_HOOK_MAX_DELAY_MS: u64 = 5_000; + impl Xhci { + fn read_test_hook_command_from_path(path: &str) -> Option { + let contents = fs::read_to_string(path).ok()?; + contents + .lines() + .map(|line| line.trim()) + .find(|line| !line.is_empty() && !line.starts_with('#')) + .map(|line| line.to_owned()) + } + + fn clear_test_hook_command_path(path: &str) { + if let Err(err) = fs::remove_file(path) { + if err.kind() != std::io::ErrorKind::NotFound { + warn!("failed to remove xhcid test hook file {}: {}", path, err); + } + } + } + + fn consume_test_hook_from_path(path: &str, expected: &str) -> bool { + match Self::read_test_hook_command_from_path(path) { + Some(command) if command == expected => { + Self::clear_test_hook_command_path(path); + true + } + _ => false, + } + } + + fn consume_test_hook_delay_ms_from_path(path: &str, prefix: &str) -> Option { + let command = Self::read_test_hook_command_from_path(path)?; + let delay_ms = command.strip_prefix(prefix)?.parse::().ok()?; + Self::clear_test_hook_command_path(path); + Some(delay_ms.min(XHCID_TEST_HOOK_MAX_DELAY_MS)) + } + + pub(crate) fn consume_test_hook(&self, expected: &str) -> bool { + Self::consume_test_hook_from_path(XHCID_TEST_HOOK_PATH, expected) + } + + pub(crate) fn consume_test_hook_delay_ms(&self, prefix: &str) -> Option { + Self::consume_test_hook_delay_ms_from_path(XHCID_TEST_HOOK_PATH, prefix) + } + /// Gets descriptors, before the port state is initiated. async fn get_desc_raw( &self, @@ -104,7 +150,17 @@ impl Xhci { ); let future = { - let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(ENOENT))?; + let mut published_port_state = self.port_states.get_mut(&port); + let mut staged_port_state = if published_port_state.is_none() { + self.staged_port_states.get_mut(&port) + } else { + None + }; + + let port_state = published_port_state + .as_deref_mut() + .or_else(|| staged_port_state.as_deref_mut()) + .ok_or(Error::new(ENOENT))?; let ring = port_state .endpoint_states .get_mut(&0) @@ -283,6 +339,7 @@ pub struct Xhci { handles: CHashMap, next_handle: AtomicUsize, port_states: CHashMap>, + staged_port_states: CHashMap>, drivers: CHashMap>, scheme_name: String, @@ -308,9 +365,97 @@ struct PortState { slot: u8, protocol_speed: &'static ProtocolSpeed, cfg_idx: Option, + active_ifaces: BTreeMap, // iface number → active alternate setting input_context: Mutex>>, dev_desc: Option, endpoint_states: BTreeMap, + lifecycle: Arc, + pm_state: PortPmState, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum PortLifecycleState { + Attaching, + Attached, + Detaching, +} + +struct PortLifecycleInner { + state: PortLifecycleState, + active_operations: usize, +} + +pub(crate) struct PortLifecycle { + inner: Mutex, + idle: Condvar, +} + +impl PortLifecycle { + pub(crate) fn new_attaching() -> Self { + Self { + inner: Mutex::new(PortLifecycleInner { + state: PortLifecycleState::Attaching, + active_operations: 1, + }), + idle: Condvar::new(), + } + } + + fn lock_inner(&self) -> std::sync::MutexGuard<'_, PortLifecycleInner> { + self.inner.lock().unwrap_or_else(|err| err.into_inner()) + } + + pub(crate) fn finish_attach_success(&self) -> PortLifecycleState { + let mut inner = self.lock_inner(); + + if inner.state == PortLifecycleState::Attaching { + inner.state = PortLifecycleState::Attached; + } + + if inner.active_operations != 0 { + inner.active_operations -= 1; + } + if inner.active_operations == 0 { + self.idle.notify_all(); + } + + inner.state + } + + pub(crate) fn finish_attach_failure(&self) { + let mut inner = self.lock_inner(); + inner.state = PortLifecycleState::Detaching; + + if inner.active_operations != 0 { + inner.active_operations -= 1; + } + if inner.active_operations == 0 { + self.idle.notify_all(); + } + } + + pub(crate) fn begin_detaching(&self) { + let mut inner = self.lock_inner(); + inner.state = PortLifecycleState::Detaching; + + while inner.active_operations != 0 { + inner = self.idle.wait(inner).unwrap_or_else(|err| err.into_inner()); + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum PortPmState { + Active, + Suspended, +} +impl PortPmState { + pub fn as_str(&self) -> &'static str { + match self { + Self::Active => "active", + Self::Suspended => "suspended", + } + } } impl PortState { @@ -463,6 +608,7 @@ impl Xhci { handles: CHashMap::new(), next_handle: AtomicUsize::new(0), port_states: CHashMap::new(), + staged_port_states: CHashMap::new(), drivers: CHashMap::new(), scheme_name, @@ -793,11 +939,14 @@ impl Xhci { } pub async fn attach_device(&self, port_id: PortId) -> syscall::Result<()> { - if self.port_states.contains_key(&port_id) { + if self.port_states.contains_key(&port_id) || self.staged_port_states.contains_key(&port_id) + { debug!("Already contains port {}", port_id); return Err(syscall::Error::new(EAGAIN)); } + info!("xhcid: begin attach for port {}", port_id); + let (data, state, speed, flags) = { let port = &self.ports.lock().unwrap()[port_id.root_hub_port_index()]; (port.read(), port.state(), port.speed(), port.flags()) @@ -808,74 +957,102 @@ impl Xhci { port_id, data, state, speed, flags ); - if flags.contains(port::PortFlags::CCS) { - let slot_ty = match self.supported_protocol(port_id) { - Some(protocol) => protocol.proto_slot_ty(), - None => { - warn!("Failed to find supported protocol information for port"); - 0 - } - }; - - debug!("Slot type: {}", slot_ty); - debug!("Enabling slot."); - let slot = match self.enable_port_slot(slot_ty).await { - Ok(ok) => ok, - Err(err) => { - error!("Failed to enable slot for port {}: {}", port_id, err); - return Err(err); - } - }; + if !flags.contains(port::PortFlags::CCS) { + warn!("Attempted to attach a device that didnt have CCS=1"); + return Ok(()); + } - debug!("Enabled port {}, which the xHC mapped to {}", port_id, slot); + let slot_ty = match self.supported_protocol(port_id) { + Some(protocol) => protocol.proto_slot_ty(), + None => { + warn!("Failed to find supported protocol information for port"); + 0 + } + }; - //TODO: get correct speed for child devices - let protocol_speed = self - .lookup_psiv(port_id, speed) - .expect("Failed to retrieve speed ID"); + debug!("Slot type: {}", slot_ty); + debug!("Enabling slot."); + let slot = match self.enable_port_slot(slot_ty).await { + Ok(ok) => ok, + Err(err) => { + error!("Failed to enable slot for port {}: {}", port_id, err); + return Err(err); + } + }; - let mut input = unsafe { self.alloc_dma_zeroed::>()? }; + debug!("Enabled port {}, which the xHC mapped to {}", port_id, slot); - debug!("Attempting to address the device"); - let mut ring = match self - .address_device(&mut input, port_id, slot_ty, slot, protocol_speed, speed) - .await - { - Ok(device_ring) => device_ring, - Err(err) => { - error!("Failed to address device for port {}: `{}`", port_id, err); - return Err(err); + let protocol_speed = match self.lookup_psiv(port_id, speed) { + Some(protocol_speed) => protocol_speed, + None => { + let err = Error::new(EIO); + error!("Failed to retrieve speed ID for port {}", port_id); + if let Err(disable_err) = self.disable_port_slot(slot).await { + warn!( + "Failed to disable slot {} after speed lookup failure on port {}: {}", + slot, port_id, disable_err + ); } - }; + return Err(err); + } + }; - debug!("Addressed device"); + let mut input = unsafe { self.alloc_dma_zeroed::>()? }; - // TODO: Should the descriptors be cached in PortState, or refetched? + debug!("Attempting to address the device"); + let ring = match self + .address_device(&mut input, port_id, slot_ty, slot, protocol_speed, speed) + .await + { + Ok(device_ring) => device_ring, + Err(err) => { + error!("Failed to address device for port {}: `{}`", port_id, err); + if let Err(disable_err) = self.disable_port_slot(slot).await { + warn!( + "Failed to disable slot {} after address failure on port {}: {}", + slot, port_id, disable_err + ); + } + return Err(err); + } + }; - let mut port_state = PortState { - slot, - protocol_speed, - input_context: Mutex::new(input), - dev_desc: None, - cfg_idx: None, - endpoint_states: std::iter::once(( - 0, - EndpointState { - transfer: RingOrStreams::Ring(ring), - driver_if_state: EndpIfState::Init, - }, - )) - .collect::>(), - }; - self.port_states.insert(port_id, port_state); - debug!("Got port states!"); + debug!("Addressed device"); - // Ensure correct packet size is used + let lifecycle = Arc::new(PortLifecycle::new_attaching()); + let port_state = PortState { + slot, + protocol_speed, + input_context: Mutex::new(input), + dev_desc: None, + cfg_idx: None, + active_ifaces: BTreeMap::new(), + endpoint_states: std::iter::once(( + 0, + EndpointState { + transfer: RingOrStreams::Ring(ring), + driver_if_state: EndpIfState::Init, + }, + )) + .collect::>(), + lifecycle: Arc::clone(&lifecycle), + pm_state: PortPmState::Active, + }; + self.staged_port_states.insert(port_id, port_state); + debug!("Got staged port state!"); + + let attach_result = async { let dev_desc_8_byte = self.fetch_dev_desc_8_byte(port_id, slot).await?; { - let mut port_state = self.port_states.get_mut(&port_id).unwrap(); + let mut port_state = self + .staged_port_states + .get_mut(&port_id) + .ok_or(Error::new(ENOENT))?; - let mut input = port_state.input_context.lock().unwrap(); + let mut input = port_state + .input_context + .lock() + .unwrap_or_else(|err| err.into_inner()); self.update_max_packet_size(&mut *input, slot, dev_desc_8_byte) .await?; @@ -885,97 +1062,175 @@ impl Xhci { let dev_desc = self.get_desc(port_id, slot).await?; debug!("Got the full device descriptor!"); - self.port_states.get_mut(&port_id).unwrap().dev_desc = Some(dev_desc); + self.staged_port_states + .get_mut(&port_id) + .ok_or(Error::new(ENOENT))? + .dev_desc = Some(dev_desc); debug!("Got the port states again!"); { - let mut port_state = self.port_states.get_mut(&port_id).unwrap(); - - let mut input = port_state.input_context.lock().unwrap(); + let mut port_state = self + .staged_port_states + .get_mut(&port_id) + .ok_or(Error::new(ENOENT))?; + + let mut input = port_state + .input_context + .lock() + .unwrap_or_else(|err| err.into_inner()); debug!("Got the input context!"); - let dev_desc = port_state.dev_desc.as_ref().unwrap(); + let dev_desc = port_state.dev_desc.as_ref().ok_or(Error::new(EIO))?; self.update_default_control_pipe(&mut *input, slot, dev_desc) .await?; } debug!("Updated the default control pipe"); + Ok(()) + } + .await; - match self.spawn_drivers(port_id) { - Ok(()) => (), - Err(err) => { - error!("Failed to spawn driver for port {}: `{}`", port_id, err) + match attach_result { + Ok(()) => { + if let Some(delay_ms) = + self.consume_test_hook_delay_ms("delay_before_attach_commit_ms=") + { + info!( + "xhcid: test hook delaying attach commit for port {} by {} ms", + port_id, delay_ms + ); + thread::sleep(Duration::from_millis(delay_ms)); } + + if lifecycle.finish_attach_success() != PortLifecycleState::Attached { + warn!( + "attach for port {} completed after detach already started; skipping publication", + port_id + ); + return Err(Error::new(EBUSY)); + } + + let staged_port_state = self + .staged_port_states + .remove(&port_id) + .ok_or(Error::new(ENOENT))?; + self.port_states.insert(port_id, staged_port_state); + + match self.spawn_drivers(port_id) { + Ok(()) => (), + Err(err) => { + error!("Failed to spawn driver for port {}: `{}`", port_id, err) + } + } + + info!("xhcid: finished attach for port {}", port_id); + Ok(()) + } + Err(err) => { + lifecycle.finish_attach_failure(); + if let Err(detach_err) = self.detach_device(port_id).await { + warn!( + "failed to clean up attach failure on port {}: {}", + port_id, detach_err + ); + } + Err(err) } - } else { - warn!("Attempted to attach a device that didnt have CCS=1"); } - - Ok(()) } pub async fn detach_device(&self, port_id: PortId) -> Result { - if let Some(children) = self.drivers.remove(&port_id) { - for mut child in children { - info!("killing driver process {} for port {}", child.id(), port_id); - match child.kill() { - Ok(()) => { - info!("killed driver process {} for port {}", child.id(), port_id); - match child.try_wait() { - Ok(status_opt) => match status_opt { - Some(status) => { - debug!( - "driver process {} for port {} exited with status {}", - child.id(), - port_id, - status - ); - } - None => { - //TODO: kill harder + let published_state = self.port_states.get(&port_id); + let staged_state = if published_state.is_none() { + self.staged_port_states.get(&port_id) + } else { + None + }; + + let (slot, lifecycle, was_published) = match published_state + .as_deref() + .or_else(|| staged_state.as_deref()) + { + Some(state) => (state.slot, Arc::clone(&state.lifecycle), published_state.is_some()), + None => { + debug!( + "Attempted to detach from port {}, which wasn't previously attached.", + port_id + ); + return Ok(false); + } + }; + drop(published_state); + drop(staged_state); + + lifecycle.begin_detaching(); + + if was_published { + if let Some(children) = self.drivers.remove(&port_id) { + for mut child in children { + info!("killing driver process {} for port {}", child.id(), port_id); + match child.kill() { + Ok(()) => { + info!("killed driver process {} for port {}", child.id(), port_id); + match child.try_wait() { + Ok(status_opt) => match status_opt { + Some(status) => { + debug!( + "driver process {} for port {} exited with status {}", + child.id(), + port_id, + status + ); + } + None => { + warn!( + "driver process {} for port {} still running", + child.id(), + port_id + ); + } + }, + Err(err) => { warn!( - "driver process {} for port {} still running", + "failed to wait for the driver process {} for port {}: {}", child.id(), - port_id + port_id, + err ); } - }, - Err(err) => { - warn!( - "failed to wait for the driver process {} for port {}: {}", - child.id(), - port_id, - err - ); } } - } - Err(err) => { - warn!( - "failed to kill the driver process {} for port {}: {}", - child.id(), - port_id, - err - ); + Err(err) => { + warn!( + "failed to kill the driver process {} for port {}: {}", + child.id(), + port_id, + err + ); + } } } } } - if let Some(state) = self.port_states.remove(&port_id) { - debug!("disabling port slot {} for port {}", state.slot, port_id); - let result = self.disable_port_slot(state.slot).await.and(Ok(true)); - debug!( - "disabled port slot {} for port {} with result: {:?}", - state.slot, port_id, result - ); - result - } else { - debug!( - "Attempted to detach from port {}, which wasn't previously attached.", - port_id - ); - Ok(false) + debug!("disabling port slot {} for port {}", slot, port_id); + match self.disable_port_slot(slot).await { + Ok(()) => { + if was_published { + let _ = self.port_states.remove(&port_id); + } else { + let _ = self.staged_port_states.remove(&port_id); + } + debug!("disabled port slot {} for port {}", slot, port_id); + Ok(true) + } + Err(err) => { + warn!( + "failed to disable port slot {} for port {}: {}", + slot, port_id, err + ); + Err(err) + } } } @@ -1246,14 +1501,12 @@ impl Xhci { let drivers_usercfg: &DriversConfig = &DRIVERS_CONFIG; for ifdesc in config_desc.interface_descs.iter() { - //TODO: support alternate settings - // This is difficult because the device driver must know which alternate - // to use, but if alternates can have different classes, then a different - // device driver may be required for each alternate. For now, we will use - // only the default alternate setting (0) + // Only auto-spawn drivers for the default alternate setting (0). + // Non-default alternates are selected later by the device driver + // via SET_INTERFACE + configure_endpoints with specific alternate_setting. if ifdesc.alternate_setting != 0 { - warn!( - "ignoring port {} iface {} alternate {} class {}.{} proto {}", + debug!( + "skipping port {} iface {} alternate {} class {}.{} proto {} (non-default alternate)", port, ifdesc.number, ifdesc.alternate_setting, @@ -1458,6 +1711,53 @@ pub fn start_device_enumerator(hci: &Arc>) { })); } +#[cfg(test)] +mod tests { + use std::fs; + use std::path::Path; + use std::time::{SystemTime, UNIX_EPOCH}; + + use super::{Xhci, XHCID_TEST_HOOK_MAX_DELAY_MS}; + + fn unique_test_hook_path() -> String { + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + format!("/tmp/xhcid-test-hook-{}", unique) + } + + #[test] + fn consume_test_hook_only_clears_matching_command() { + let path = unique_test_hook_path(); + fs::write(&path, "fail_after_set_configuration\n").unwrap(); + + assert!(!Xhci::<16>::consume_test_hook_from_path( + &path, + "fail_after_configure_endpoint" + )); + assert!(Path::new(&path).exists()); + + assert!(Xhci::<16>::consume_test_hook_from_path( + &path, + "fail_after_set_configuration" + )); + assert!(!Path::new(&path).exists()); + } + + #[test] + fn consume_test_hook_delay_clamps_and_clears() { + let path = unique_test_hook_path(); + fs::write(&path, "delay_before_attach_commit_ms=999999\n").unwrap(); + + assert_eq!( + Xhci::<16>::consume_test_hook_delay_ms_from_path(&path, "delay_before_attach_commit_ms="), + Some(XHCID_TEST_HOOK_MAX_DELAY_MS) + ); + assert!(!Path::new(&path).exists()); + } +} + #[derive(Deserialize)] struct DriverConfig { name: String, diff --git a/drivers/usb/xhcid/src/xhci/scheme.rs b/drivers/usb/xhcid/src/xhci/scheme.rs index ca27b3fe..29437294 100644 --- a/drivers/usb/xhcid/src/xhci/scheme.rs +++ b/drivers/usb/xhcid/src/xhci/scheme.rs @@ -20,6 +20,7 @@ use std::convert::TryFrom; use std::io::prelude::*; use std::ops::Deref; use std::sync::atomic; +use std::collections::BTreeMap; use std::{cmp, fmt, io, mem, str}; use common::dma::Dma; @@ -33,9 +34,9 @@ use common::io::Io; use redox_scheme::{CallerCtx, OpenResult}; use syscall::schemev2::NewFdFlags; use syscall::{ - Error, Result, Stat, EACCES, EBADF, EBADFD, EBADMSG, EINVAL, EIO, EISDIR, ENOENT, ENOSYS, - ENOTDIR, EOPNOTSUPP, EPROTO, ESPIPE, MODE_CHR, MODE_DIR, MODE_FILE, O_DIRECTORY, O_RDWR, - O_STAT, O_WRONLY, SEEK_CUR, SEEK_END, SEEK_SET, + Error, Result, Stat, EACCES, EBADF, EBADFD, EBADMSG, EBUSY, EINVAL, EIO, EISDIR, ENOENT, + ENOSYS, ENOTDIR, EOPNOTSUPP, EPROTO, ESPIPE, MODE_CHR, MODE_DIR, MODE_FILE, O_DIRECTORY, + O_RDWR, O_STAT, O_WRONLY, SEEK_CUR, SEEK_END, SEEK_SET, }; use super::{port, usb}; @@ -61,10 +62,16 @@ lazy_static! { .expect("Failed to create the regex for the port/attach scheme."); static ref REGEX_PORT_DETACH: Regex = Regex::new(r"^port([\d\.]+)/detach$") .expect("Failed to create the regex for the port/detach scheme."); + static ref REGEX_PORT_SUSPEND: Regex = Regex::new(r"^port([\d\.]+)/suspend$") + .expect("Failed to create the regex for the port/suspend scheme."); + static ref REGEX_PORT_RESUME: Regex = Regex::new(r"^port([\d\.]+)/resume$") + .expect("Failed to create the regex for the port/resume scheme."); static ref REGEX_PORT_DESCRIPTORS: Regex = Regex::new(r"^port([\d\.]+)/descriptors$") .expect("Failed to create the regex for the port/descriptors"); static ref REGEX_PORT_STATE: Regex = Regex::new(r"^port([\d\.]+)/state$") .expect("Failed to create the regex for the port/state scheme"); + static ref REGEX_PORT_PM_STATE: Regex = Regex::new(r"^port([\d\.]+)/pm_state$") + .expect("Failed to create the regex for the port/pm_state scheme"); static ref REGEX_PORT_REQUEST: Regex = Regex::new(r"^port([\d\.]+)/request$") .expect("Failed to create the regex for the port/request scheme"); static ref REGEX_PORT_ENDPOINTS: Regex = Regex::new(r"^port([\d\.]+)/endpoints$") @@ -138,12 +145,15 @@ pub enum Handle { Port(PortId, Vec), // port, contents PortDesc(PortId, Vec), // port, contents PortState(PortId), // port + PortPmState(PortId), // port PortReq(PortId, PortReqState), // port, state Endpoints(PortId, Vec), // port, contents Endpoint(PortId, u8, EndpointHandleTy), // port, endpoint, state ConfigureEndpoints(PortId), // port AttachDevice(PortId), // port DetachDevice(PortId), // port + SuspendDevice(PortId), // port + ResumeDevice(PortId), // port SchemeRoot, } @@ -173,6 +183,8 @@ enum SchemeParameters { PortDesc(PortId), // port number /// /port/state PortState(PortId), // port number + /// /port/pm_state + PortPmState(PortId), // port number /// /port/request PortReq(PortId), // port number /// /port/endpoints @@ -188,6 +200,10 @@ enum SchemeParameters { AttachDevice(PortId), // port number /// /port/detach DetachDevice(PortId), // port number + /// /port/suspend + SuspendDevice(PortId), // port number + /// /port/resume + ResumeDevice(PortId), // port number } impl Handle { @@ -210,6 +226,9 @@ impl Handle { Handle::PortState(port_num) => { format!("port{}/state", port_num) } + Handle::PortPmState(port_num) => { + format!("port{}/pm_state", port_num) + } Handle::PortReq(port_num, _) => { format!("port{}/request", port_num) } @@ -236,6 +255,12 @@ impl Handle { Handle::DetachDevice(port_num) => { format!("port{}/detach", port_num) } + Handle::SuspendDevice(port_num) => { + format!("port{}/suspend", port_num) + } + Handle::ResumeDevice(port_num) => { + format!("port{}/resume", port_num) + } Handle::SchemeRoot => String::from(""), } } @@ -259,10 +284,13 @@ impl Handle { &Handle::PortReq(_, PortReqState::Tmp) => unreachable!(), &Handle::PortReq(_, PortReqState::TmpSetup(_)) => unreachable!(), &Handle::PortState(_) => HandleType::Character, + &Handle::PortPmState(_) => HandleType::Character, &Handle::PortReq(_, _) => HandleType::Character, &Handle::ConfigureEndpoints(_) => HandleType::Character, &Handle::AttachDevice(_) => HandleType::Character, &Handle::DetachDevice(_) => HandleType::Character, + &Handle::SuspendDevice(_) => HandleType::Character, + &Handle::ResumeDevice(_) => HandleType::Character, &Handle::Endpoint(_, _, ref st) => match st { EndpointHandleTy::Data => HandleType::Character, EndpointHandleTy::Ctl => HandleType::Character, @@ -290,10 +318,13 @@ impl Handle { &Handle::PortReq(_, PortReqState::Tmp) => None, &Handle::PortReq(_, PortReqState::TmpSetup(_)) => None, &Handle::PortState(_) => None, + &Handle::PortPmState(_) => None, &Handle::PortReq(_, _) => None, &Handle::ConfigureEndpoints(_) => None, &Handle::AttachDevice(_) => None, &Handle::DetachDevice(_) => None, + &Handle::SuspendDevice(_) => None, + &Handle::ResumeDevice(_) => None, &Handle::Endpoint(_, _, ref st) => match st { EndpointHandleTy::Data => None, EndpointHandleTy::Ctl => None, @@ -384,6 +415,14 @@ impl SchemeParameters { let port_num = get_port_id_from_regex(®EX_PORT_DETACH, scheme, 0)?; Ok(Self::DetachDevice(port_num)) + } else if REGEX_PORT_SUSPEND.is_match(scheme) { + let port_num = get_port_id_from_regex(®EX_PORT_SUSPEND, scheme, 0)?; + + Ok(Self::SuspendDevice(port_num)) + } else if REGEX_PORT_RESUME.is_match(scheme) { + let port_num = get_port_id_from_regex(®EX_PORT_RESUME, scheme, 0)?; + + Ok(Self::ResumeDevice(port_num)) } else if REGEX_PORT_DESCRIPTORS.is_match(scheme) { let port_num = get_port_id_from_regex(®EX_PORT_DESCRIPTORS, scheme, 0)?; @@ -392,6 +431,10 @@ impl SchemeParameters { let port_num = get_port_id_from_regex(®EX_PORT_STATE, scheme, 0)?; Ok(Self::PortState(port_num)) + } else if REGEX_PORT_PM_STATE.is_match(scheme) { + let port_num = get_port_id_from_regex(®EX_PORT_PM_STATE, scheme, 0)?; + + Ok(Self::PortPmState(port_num)) } else if REGEX_PORT_REQUEST.is_match(scheme) { let port_num = get_port_id_from_regex(®EX_PORT_REQUEST, scheme, 0)?; @@ -524,6 +567,39 @@ pub enum AnyDescriptor { SuperSpeedPlusCompanion(usb::SuperSpeedPlusIsochCmpDescriptor), } +#[derive(Clone, Copy)] +struct ConfigureContextSnapshot { + add_context: u32, + drop_context: u32, + control: u32, + slot_a: u32, + slot_b: u32, +} + +#[derive(Clone, Copy)] +struct EndpointContextSnapshot { + a: u32, + b: u32, + trl: u32, + trh: u32, + c: u32, +} + +impl EndpointContextSnapshot { + fn capture_values(a: u32, b: u32, trl: u32, trh: u32, c: u32) -> Self { + Self { a, b, trl, trh, c } + } +} + +struct EndpointProgram { + endp_num_xhc: u8, + a: u32, + b: u32, + trl: u32, + trh: u32, + c: u32, +} + impl AnyDescriptor { fn parse(bytes: &[u8]) -> Option<(Self, usize)> { if bytes.len() < 2 { @@ -640,6 +716,8 @@ impl Xhci { where D: FnMut(&mut Trb, bool) -> ControlFlow, { + self.ensure_port_active(port_num)?; + let future = { let mut port_state = self.port_state_mut(port_num)?; let slot = port_state.slot; @@ -710,6 +788,8 @@ impl Xhci { where D: FnMut(&mut Trb, bool) -> ControlFlow, { + self.ensure_port_active(port_num)?; + let endp_idx = endp_num.checked_sub(1).ok_or(Error::new(EIO))?; let mut port_state = self.port_state_mut(port_num)?; @@ -835,7 +915,10 @@ impl Xhci { port, usb::Setup::set_interface(interface_num, alternate_setting), ) - .await + .await?; + let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; + port_state.active_ifaces.insert(interface_num, alternate_setting); + Ok(()) } async fn reset_endpoint(&self, port_num: PortId, endp_num: u8, tsp: bool) -> Result<()> { @@ -950,35 +1033,114 @@ impl Xhci { self.port_states.get_mut(&port).ok_or(Error::new(EBADF)) } + fn restore_configure_input_context( + &self, + port: PortId, + snapshot: ConfigureContextSnapshot, + endpoint_snapshots: &[(usize, EndpointContextSnapshot)], + ) -> Result { + let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; + let mut input_context = port_state.input_context.lock().unwrap(); + + input_context.add_context.write(snapshot.add_context); + input_context.drop_context.write(snapshot.drop_context); + input_context.control.write(snapshot.control); + input_context.device.slot.a.write(snapshot.slot_a); + input_context.device.slot.b.write(snapshot.slot_b); + + for (endp_i, endp_snapshot) in endpoint_snapshots { + input_context.device.endpoints[*endp_i].a.write(endp_snapshot.a); + input_context.device.endpoints[*endp_i].b.write(endp_snapshot.b); + input_context.device.endpoints[*endp_i].trl.write(endp_snapshot.trl); + input_context.device.endpoints[*endp_i].trh.write(endp_snapshot.trh); + input_context.device.endpoints[*endp_i].c.write(endp_snapshot.c); + } + + Ok(input_context.physical()) + } + + async fn rollback_configure_attempt( + &self, + port: PortId, + slot: u8, + configure_snapshot: ConfigureContextSnapshot, + endpoint_snapshots: &[(usize, EndpointContextSnapshot)], + stage: &str, + ) { + let rollback_input_context_physical = match self.restore_configure_input_context( + port, + configure_snapshot, + endpoint_snapshots, + ) { + Ok(physical) => physical, + Err(restore_err) => { + warn!( + "failed to restore configure input context after {}: {:?}", + stage, restore_err + ); + return; + } + }; + + let (rollback_event_trb, rollback_command_trb) = self + .execute_command(|trb, cycle| { + trb.configure_endpoint(slot, rollback_input_context_physical, cycle) + }) + .await; + + if let Err(rollback_err) = handle_event_trb( + "CONFIGURE_ENDPOINT_ROLLBACK", + &rollback_event_trb, + &rollback_command_trb, + ) { + warn!( + "failed to roll back CONFIGURE_ENDPOINT after {}: {:?}", + stage, rollback_err + ); + } + } + async fn configure_endpoints_once( &self, port: PortId, req: &ConfigureEndpointsReq, ) -> Result<()> { - let (endp_desc_count, new_context_entries, configuration_value) = { - let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; - - port_state.cfg_idx = Some(req.config_desc); + let (dev_desc, endpoint_descs, new_context_entries, configuration_value, speed_id) = { + let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; + let dev_desc = port_state.dev_desc.as_ref().ok_or(Error::new(EBADFD))?.clone(); + let speed_id = port_state.protocol_speed; - let config_desc = port_state - .dev_desc - .as_ref() - .unwrap() + let config_desc = dev_desc .config_descs .iter() .find(|desc| desc.configuration_value == req.config_desc) .ok_or(Error::new(EBADFD))?; - //TODO: USE ENDPOINTS FROM ALL INTERFACES - let mut endp_desc_count = 0; - let mut new_context_entries = 1; - for if_desc in config_desc.interface_descs.iter() { - for endpoint in if_desc.endpoints.iter() { - endp_desc_count += 1; - let entry = Self::endp_num_to_dci(endp_desc_count, endpoint); - if entry > new_context_entries { - new_context_entries = entry; - } + let configuration_value = config_desc.configuration_value; + + let endpoint_descs = if let Some(iface_num) = req.interface_desc { + let alt = req.alternate_setting.unwrap_or(0); + config_desc + .interface_descs + .iter() + .filter(|if_desc| if_desc.number == iface_num && if_desc.alternate_setting == alt) + .flat_map(|if_desc| if_desc.endpoints.iter().copied()) + .collect::>() + } else { + config_desc + .interface_descs + .iter() + .filter(|if_desc| if_desc.alternate_setting == 0) + .flat_map(|if_desc| if_desc.endpoints.iter().copied()) + .collect::>() + }; + + let endp_desc_count = endpoint_descs.len(); + let mut new_context_entries = 1u8; + for (endp_idx, endpoint) in endpoint_descs.iter().enumerate() { + let entry = Self::endp_num_to_dci(endp_idx as u8 + 1, endpoint); + if entry > new_context_entries { + new_context_entries = entry; } } new_context_entries += 1; @@ -989,74 +1151,22 @@ impl Xhci { } ( - endp_desc_count, + dev_desc, + endpoint_descs, new_context_entries, - config_desc.configuration_value, + configuration_value, + speed_id, ) }; let lec = self.cap.lec(); let log_max_psa_size = self.cap.max_psa_size(); - let port_speed_id = self.ports.lock().unwrap()[port.root_hub_port_index()].speed(); - let speed_id: &ProtocolSpeed = self.lookup_psiv(port, port_speed_id).ok_or_else(|| { - warn!("no speed_id"); - Error::new(EIO) - })?; - - { - let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; - let mut input_context = port_state.input_context.lock().unwrap(); - - // Configure the slot context as well, which holds the last index of the endp descs. - input_context.add_context.write(1); - input_context.drop_context.write(0); - - const CONTEXT_ENTRIES_MASK: u32 = 0xF800_0000; - const CONTEXT_ENTRIES_SHIFT: u8 = 27; - - const HUB_PORTS_MASK: u32 = 0xFF00_0000; - const HUB_PORTS_SHIFT: u8 = 24; - - let mut current_slot_a = input_context.device.slot.a.read(); - let mut current_slot_b = input_context.device.slot.b.read(); - - // Set context entries - current_slot_a &= !CONTEXT_ENTRIES_MASK; - current_slot_a |= - (u32::from(new_context_entries) << CONTEXT_ENTRIES_SHIFT) & CONTEXT_ENTRIES_MASK; - - // Set hub data - current_slot_a &= !(1 << 26); - current_slot_b &= !HUB_PORTS_MASK; - if let Some(hub_ports) = req.hub_ports { - current_slot_a |= 1 << 26; - current_slot_b |= (u32::from(hub_ports) << HUB_PORTS_SHIFT) & HUB_PORTS_MASK; - } - - input_context.device.slot.a.write(current_slot_a); - input_context.device.slot.b.write(current_slot_b); - - let control = if self.op.lock().unwrap().cie() { - (u32::from(req.alternate_setting.unwrap_or(0)) << 16) - | (u32::from(req.interface_desc.unwrap_or(0)) << 8) - | u32::from(configuration_value) - } else { - 0 - }; - input_context.control.write(control); - } + let mut staged_endpoint_states = BTreeMap::new(); + let mut endpoint_programs = Vec::new(); - for endp_idx in 0..endp_desc_count as u8 { - let endp_num = endp_idx + 1; - - let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; - let dev_desc = port_state.dev_desc.as_ref().unwrap(); - let endp_desc = port_state.get_endp_desc(endp_idx).ok_or_else(|| { - warn!("failed to find endpoint {}", endp_idx); - Error::new(EIO) - })?; - - let endp_num_xhc = Self::endp_num_to_dci(endp_num, endp_desc); + for (endp_idx, endp_desc) in endpoint_descs.iter().copied().enumerate() { + let endp_num = endp_idx as u8 + 1; + let endp_num_xhc = Self::endp_num_to_dci(endp_num, &endp_desc); let usb_log_max_streams = endp_desc.log_max_streams(); @@ -1078,20 +1188,20 @@ impl Xhci { let mult = endp_desc.isoch_mult(lec); - let max_packet_size = Self::endp_ctx_max_packet_size(endp_desc); - let max_burst_size = Self::endp_ctx_max_burst(speed_id, dev_desc, endp_desc); + let max_packet_size = Self::endp_ctx_max_packet_size(&endp_desc); + let max_burst_size = Self::endp_ctx_max_burst(speed_id, &dev_desc, &endp_desc); let max_esit_payload = Self::endp_ctx_max_esit_payload( speed_id, - dev_desc, - endp_desc, + &dev_desc, + &endp_desc, max_packet_size, max_burst_size, ); let max_esit_payload_lo = max_esit_payload as u16; let max_esit_payload_hi = ((max_esit_payload & 0x00FF_0000) >> 16) as u8; - let interval = Self::endp_ctx_interval(speed_id, endp_desc); + let interval = Self::endp_ctx_interval(speed_id, &endp_desc); let max_error_count = 3; let ep_ty = endp_desc.xhci_ep_type()?; @@ -1114,7 +1224,7 @@ impl Xhci { assert_eq!(max_error_count & 0x3, max_error_count); assert_ne!(ep_ty, 0); // 0 means invalid. - let ring_ptr = if usb_log_max_streams.is_some() { + let (endpoint_state, ring_ptr) = if usb_log_max_streams.is_some() { let mut array = StreamContextArray::new::(self.cap.ac64(), 1 << (primary_streams + 1))?; @@ -1127,15 +1237,13 @@ impl Xhci { array_ptr, "stream ctx ptr not aligned to 16 bytes" ); - port_state.endpoint_states.insert( - endp_num, + ( EndpointState { transfer: super::RingOrStreams::Streams(array), driver_if_state: EndpIfState::Init, }, - ); - - array_ptr + array_ptr, + ) } else { let ring = Ring::new::(self.cap.ac64(), 16, true)?; let ring_ptr = ring.register(); @@ -1145,68 +1253,205 @@ impl Xhci { ring_ptr, "ring pointer not aligned to 16 bytes" ); - port_state.endpoint_states.insert( - endp_num, + ( EndpointState { transfer: super::RingOrStreams::Ring(ring), driver_if_state: EndpIfState::Init, }, - ); - ring_ptr + ring_ptr, + ) }; assert_eq!(primary_streams & 0x1F, primary_streams); - let mut input_context = port_state.input_context.lock().unwrap(); - input_context.add_context.writef(1 << endp_num_xhc, true); - - let endp_i = endp_num_xhc as usize - 1; - input_context.device.endpoints[endp_i].a.write( - u32::from(mult) << 8 + staged_endpoint_states.insert(endp_num, endpoint_state); + endpoint_programs.push(EndpointProgram { + endp_num_xhc, + a: u32::from(mult) << 8 | u32::from(primary_streams) << 10 | u32::from(linear_stream_array) << 15 | u32::from(interval) << 16 | u32::from(max_esit_payload_hi) << 24, - ); - input_context.device.endpoints[endp_i].b.write( - max_error_count << 1 + b: max_error_count << 1 | u32::from(ep_ty) << 3 | u32::from(host_initiate_disable) << 7 | u32::from(max_burst_size) << 8 | u32::from(max_packet_size) << 16, - ); - - input_context.device.endpoints[endp_i] - .trl - .write(ring_ptr as u32); - input_context.device.endpoints[endp_i] - .trh - .write((ring_ptr >> 32) as u32); - - input_context.device.endpoints[endp_i] - .c - .write(u32::from(avg_trb_len) | (u32::from(max_esit_payload_lo) << 16)); + trl: ring_ptr as u32, + trh: (ring_ptr >> 32) as u32, + c: u32::from(avg_trb_len) | (u32::from(max_esit_payload_lo) << 16), + }); - log::debug!("initialized endpoint {}", endp_num); + log::debug!("staged endpoint {}", endp_num); } - { + let (configure_snapshot, endpoint_snapshots, input_context_physical) = { let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; - let slot = port_state.slot; - let input_context_physical = port_state.input_context.lock().unwrap().physical(); + let mut input_context = port_state.input_context.lock().unwrap(); + + let configure_snapshot = ConfigureContextSnapshot { + add_context: input_context.add_context.read(), + drop_context: input_context.drop_context.read(), + control: input_context.control.read(), + slot_a: input_context.device.slot.a.read(), + slot_b: input_context.device.slot.b.read(), + }; - let (event_trb, command_trb) = self - .execute_command(|trb, cycle| { - trb.configure_endpoint(slot, input_context_physical, cycle) + let endpoint_snapshots = endpoint_programs + .iter() + .map(|program| { + let endp_i = program.endp_num_xhc as usize - 1; + ( + endp_i, + EndpointContextSnapshot::capture_values( + input_context.device.endpoints[endp_i].a.read(), + input_context.device.endpoints[endp_i].b.read(), + input_context.device.endpoints[endp_i].trl.read(), + input_context.device.endpoints[endp_i].trh.read(), + input_context.device.endpoints[endp_i].c.read(), + ), + ) }) - .await; + .collect::>(); + + // Configure the slot context as well, which holds the last index of the endp descs. + input_context.add_context.write(1); + input_context.drop_context.write(0); - //self.event_handler_finished(); + const CONTEXT_ENTRIES_MASK: u32 = 0xF800_0000; + const CONTEXT_ENTRIES_SHIFT: u8 = 27; + + const HUB_PORTS_MASK: u32 = 0xFF00_0000; + const HUB_PORTS_SHIFT: u8 = 24; + + let mut current_slot_a = input_context.device.slot.a.read(); + let mut current_slot_b = input_context.device.slot.b.read(); + + current_slot_a &= !CONTEXT_ENTRIES_MASK; + current_slot_a |= + (u32::from(new_context_entries) << CONTEXT_ENTRIES_SHIFT) & CONTEXT_ENTRIES_MASK; + + current_slot_a &= !(1 << 26); + current_slot_b &= !HUB_PORTS_MASK; + if let Some(hub_ports) = req.hub_ports { + current_slot_a |= 1 << 26; + current_slot_b |= (u32::from(hub_ports) << HUB_PORTS_SHIFT) & HUB_PORTS_MASK; + } + + input_context.device.slot.a.write(current_slot_a); + input_context.device.slot.b.write(current_slot_b); + + let control = if self.op.lock().unwrap().cie() { + (u32::from(req.alternate_setting.unwrap_or(0)) << 16) + | (u32::from(req.interface_desc.unwrap_or(0)) << 8) + | u32::from(configuration_value) + } else { + 0 + }; + input_context.control.write(control); + + for program in &endpoint_programs { + let endp_i = program.endp_num_xhc as usize - 1; + input_context.add_context.writef(1 << program.endp_num_xhc, true); + input_context.device.endpoints[endp_i].a.write(program.a); + input_context.device.endpoints[endp_i].b.write(program.b); + input_context.device.endpoints[endp_i].trl.write(program.trl); + input_context.device.endpoints[endp_i].trh.write(program.trh); + input_context.device.endpoints[endp_i].c.write(program.c); + } + + (configure_snapshot, endpoint_snapshots, input_context.physical()) + }; + + let slot = self.port_states.get(&port).ok_or(Error::new(EBADFD))?.slot; + + let (event_trb, command_trb) = self + .execute_command(|trb, cycle| trb.configure_endpoint(slot, input_context_physical, cycle)) + .await; + + if let Err(err) = handle_event_trb("CONFIGURE_ENDPOINT", &event_trb, &command_trb) { + self.rollback_configure_attempt( + port, + slot, + configure_snapshot, + &endpoint_snapshots, + "CONFIGURE_ENDPOINT failure", + ) + .await; + return Err(err); + } + + if self.consume_test_hook("fail_after_configure_endpoint") { + info!( + "xhcid: test hook injecting failure after CONFIGURE_ENDPOINT for port {}", + port + ); + self.rollback_configure_attempt( + port, + slot, + configure_snapshot, + &endpoint_snapshots, + "test hook fail_after_configure_endpoint", + ) + .await; + return Err(Error::new(EIO)); + } + + if let Err(err) = self.set_configuration(port, configuration_value).await { + self.rollback_configure_attempt( + port, + slot, + configure_snapshot, + &endpoint_snapshots, + "set_configuration failure", + ) + .await; + return Err(err); + } - handle_event_trb("CONFIGURE_ENDPOINT", &event_trb, &command_trb)?; + if self.consume_test_hook("fail_after_set_configuration") { + info!( + "xhcid: test hook injecting failure after SET_CONFIGURATION for port {}", + port + ); + self.rollback_configure_attempt( + port, + slot, + configure_snapshot, + &endpoint_snapshots, + "test hook fail_after_set_configuration", + ) + .await; + return Err(Error::new(EIO)); } - // Tell the device about this configuration. - self.set_configuration(port, configuration_value).await?; + { + let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; + port_state.cfg_idx = Some(configuration_value); + port_state.endpoint_states.retain(|endp_num, _| *endp_num == 0); + for (endp_num, endpoint_state) in staged_endpoint_states { + port_state.endpoint_states.insert(endp_num, endpoint_state); + } + if let Some(iface_num) = req.interface_desc { + let alt = req.alternate_setting.unwrap_or(0); + port_state.active_ifaces.insert(iface_num, alt); + } else if port_state.active_ifaces.is_empty() { + let default_iface_entries: Vec<(u8, u8)> = port_state + .dev_desc + .as_ref() + .and_then(|dd| dd.config_descs.iter().find(|cd| cd.configuration_value == configuration_value)) + .map(|cd| { + cd.interface_descs + .iter() + .filter(|if_desc| if_desc.alternate_setting == 0) + .map(|if_desc| (if_desc.number, 0u8)) + .collect() + }) + .unwrap_or_default(); + for (iface_num, alt) in default_iface_entries { + port_state.active_ifaces.insert(iface_num, alt); + } + } + } Ok(()) } @@ -1857,7 +2102,7 @@ impl Xhci { if (flags & O_DIRECTORY != 0) || (flags & O_STAT != 0) { let mut contents = Vec::new(); - write!(contents, "descriptors\nendpoints\n").unwrap(); + write!(contents, "descriptors\nendpoints\npm_state\nsuspend\nresume\n").unwrap(); if self.slot_state( self.port_states @@ -1894,6 +2139,14 @@ impl Xhci { Ok(Handle::PortState(port_num)) } + fn open_handle_port_pm_state(&self, port_num: PortId, flags: usize) -> Result { + if flags & O_DIRECTORY != 0 && flags & O_STAT == 0 { + return Err(Error::new(ENOTDIR)); + } + + Ok(Handle::PortPmState(port_num)) + } + /// implements open() for /port/endpoints /// /// # Arguments @@ -2088,6 +2341,30 @@ impl Xhci { Ok(Handle::DetachDevice(port_num)) } + fn open_handle_suspend_device(&self, port_num: PortId, flags: usize) -> Result { + if flags & O_DIRECTORY != 0 && flags & O_STAT == 0 { + return Err(Error::new(ENOTDIR)); + } + + if flags & O_RDWR != O_WRONLY && flags & O_STAT == 0 { + return Err(Error::new(EACCES)); + } + + Ok(Handle::SuspendDevice(port_num)) + } + + fn open_handle_resume_device(&self, port_num: PortId, flags: usize) -> Result { + if flags & O_DIRECTORY != 0 && flags & O_STAT == 0 { + return Err(Error::new(ENOTDIR)); + } + + if flags & O_RDWR != O_WRONLY && flags & O_STAT == 0 { + return Err(Error::new(EACCES)); + } + + Ok(Handle::ResumeDevice(port_num)) + } + /// implements open() for /port/request /// /// # Arguments @@ -2156,6 +2433,9 @@ impl SchemeSync for &Xhci { SchemeParameters::PortState(port_number) => { self.open_handle_port_state(port_number, flags)? } + SchemeParameters::PortPmState(port_number) => { + self.open_handle_port_pm_state(port_number, flags)? + } SchemeParameters::PortReq(port_number) => { self.open_handle_port_request(port_number, flags)? } @@ -2174,6 +2454,12 @@ impl SchemeSync for &Xhci { SchemeParameters::DetachDevice(port_number) => { self.open_handle_detach_device(port_number, flags)? } + SchemeParameters::SuspendDevice(port_number) => { + self.open_handle_suspend_device(port_number, flags)? + } + SchemeParameters::ResumeDevice(port_number) => { + self.open_handle_resume_device(port_number, flags)? + } }; let fd = self.next_handle.fetch_add(1, atomic::Ordering::Relaxed); @@ -2204,7 +2490,11 @@ impl SchemeSync for &Xhci { //If we have a handle to the configure scheme, we need to mark it as write only. match &*guard { - Handle::ConfigureEndpoints(_) | Handle::AttachDevice(_) | Handle::DetachDevice(_) => { + Handle::ConfigureEndpoints(_) + | Handle::AttachDevice(_) + | Handle::DetachDevice(_) + | Handle::SuspendDevice(_) + | Handle::ResumeDevice(_) => { stat.st_mode = stat.st_mode | 0o200; } _ => {} @@ -2254,6 +2544,8 @@ impl SchemeSync for &Xhci { Handle::ConfigureEndpoints(_) => Err(Error::new(EBADF)), Handle::AttachDevice(_) => Err(Error::new(EBADF)), Handle::DetachDevice(_) => Err(Error::new(EBADF)), + Handle::SuspendDevice(_) => Err(Error::new(EBADF)), + Handle::ResumeDevice(_) => Err(Error::new(EBADF)), Handle::SchemeRoot => Err(Error::new(EBADF)), &mut Handle::Endpoint(port_num, endp_num, ref mut st) => match st { @@ -2285,6 +2577,10 @@ impl SchemeSync for &Xhci { Ok(Xhci::::write_dyn_string(string, buf, offset)) } + &mut Handle::PortPmState(port_num) => { + let ps = self.port_states.get(&port_num).ok_or(Error::new(EBADF))?; + Ok(Xhci::::write_dyn_string(ps.pm_state.as_str().as_bytes(), buf, offset)) + } &mut Handle::PortReq(port_num, ref mut st) => { let state = std::mem::replace(st, PortReqState::Tmp); drop(guard); // release the lock @@ -2324,6 +2620,14 @@ impl SchemeSync for &Xhci { block_on(self.detach_device(port_num))?; Ok(buf.len()) } + &mut Handle::SuspendDevice(port_num) => { + block_on(self.suspend_device(port_num))?; + Ok(buf.len()) + } + &mut Handle::ResumeDevice(port_num) => { + block_on(self.resume_device(port_num))?; + Ok(buf.len()) + } &mut Handle::Endpoint(port_num, endp_num, ref ep_file_ty) => match ep_file_ty { EndpointHandleTy::Ctl => block_on(self.on_write_endp_ctl(port_num, endp_num, buf)), EndpointHandleTy::Data => { @@ -2348,6 +2652,54 @@ impl SchemeSync for &Xhci { } impl Xhci { + fn ensure_port_active(&self, port_num: PortId) -> Result<()> { + let port_state = self.port_states.get(&port_num).ok_or(Error::new(EBADFD))?; + + match port_state.pm_state { + super::PortPmState::Active => Ok(()), + super::PortPmState::Suspended => { + info!( + "xhcid: port {} rejected routable operation while suspended", + port_num + ); + Err(Error::new(EBUSY)) + } + } + } + + pub async fn suspend_device(&self, port_num: PortId) -> Result<()> { + let mut port_state = self.port_states.get_mut(&port_num).ok_or(Error::new(EBADFD))?; + + if port_state.pm_state != super::PortPmState::Active { + return Err(Error::new(EBUSY)); + } + + port_state.pm_state = super::PortPmState::Suspended; + info!("xhcid: suspended port {}", port_num); + Ok(()) + } + + pub async fn resume_device(&self, port_num: PortId) -> Result<()> { + let mut port_state = self.port_states.get_mut(&port_num).ok_or(Error::new(EBADFD))?; + + if port_state.pm_state == super::PortPmState::Active { + return Ok(()); + } + + let slot_state = self.slot_state(port_state.slot as usize); + if slot_state != SlotState::Addressed as u8 && slot_state != SlotState::Configured as u8 { + warn!( + "refusing to resume port {} while slot {} is in controller state {}", + port_num, port_state.slot, slot_state + ); + return Err(Error::new(EIO)); + } + + port_state.pm_state = super::PortPmState::Active; + info!("xhcid: resumed port {}", port_num); + Ok(()) + } + pub fn get_endp_status(&self, port_num: PortId, endp_num: u8) -> Result { let port_state = self.port_states.get(&port_num).ok_or(Error::new(EBADFD))?; @@ -2398,6 +2750,8 @@ impl Xhci { endp_num: u8, clear_feature: bool, ) -> Result<()> { + self.ensure_port_active(port_num)?; + if self.get_endp_status(port_num, endp_num)? != EndpointStatus::Halted { return Err(Error::new(EPROTO)); }