diff --git a/drivers/acpid/src/acpi.rs b/drivers/acpid/src/acpi.rs index 94a1eb17..a735a4e7 100644 --- a/drivers/acpid/src/acpi.rs +++ b/drivers/acpid/src/acpi.rs @@ -25,6 +25,8 @@ use amlserde::{AmlSerde, AmlSerdeValue}; #[cfg(target_arch = "x86_64")] pub mod dmar; +#[cfg(target_arch = "x86_64")] +use self::dmar::Dmar; use crate::aml_physmem::{AmlPageCache, AmlPhysMemHandler}; /// The raw SDT header struct, as defined by the ACPI specification. @@ -379,6 +381,12 @@ pub struct AcpiContext { tables: Vec, dsdt: Option, fadt: Option, + pm1a_cnt_blk: u64, + pm1b_cnt_blk: u64, + slp_typa_s5: u8, + slp_typb_s5: u8, + reset_reg: Option, + reset_value: u8, aml_symbols: RwLock, @@ -424,6 +432,63 @@ impl AcpiContext { .flatten() } + pub fn evaluate_acpi_method( + &mut self, + path: &str, + method: &str, + args: &[u64], + ) -> Result, AmlEvalError> { + let full_path = format!("{path}.{method}"); + let aml_name = + AmlName::from_str(&full_path).map_err(|_| AmlEvalError::DeserializationError)?; + let args = args + .iter() + .copied() + .map(AmlSerdeValue::Integer) + .collect::>(); + + match self.aml_eval(aml_name, args)? 
{ + AmlSerdeValue::Integer(value) => Ok(vec![value]), + AmlSerdeValue::Package { contents } => contents + .into_iter() + .map(|value| match value { + AmlSerdeValue::Integer(value) => Ok(value), + _ => Err(AmlEvalError::DeserializationError), + }) + .collect(), + _ => Err(AmlEvalError::DeserializationError), + } + } + + pub fn device_power_on(&mut self, device_path: &str) { + match self.evaluate_acpi_method(device_path, "_PS0", &[]) { + Ok(values) => { + log::debug!("{}._PS0 => {:?}", device_path, values); + } + Err(error) => { + log::warn!("Failed to power on {} with _PS0: {:?}", device_path, error); + } + } + } + + pub fn device_power_off(&mut self, device_path: &str) { + match self.evaluate_acpi_method(device_path, "_PS3", &[]) { + Ok(values) => { + log::debug!("{}._PS3 => {:?}", device_path, values); + } + Err(error) => { + log::warn!("Failed to power off {} with _PS3: {:?}", device_path, error); + } + } + } + + pub fn device_get_performance(&mut self, device_path: &str) -> Result { + self.evaluate_acpi_method(device_path, "_PPC", &[])? + .into_iter() + .next() + .ok_or(AmlEvalError::DeserializationError) + } + pub fn init( rxsdt_physaddrs: impl Iterator, ec: Vec<(RegionSpace, Box)>, @@ -444,6 +509,12 @@ impl AcpiContext { tables, dsdt: None, fadt: None, + pm1a_cnt_blk: 0, + pm1b_cnt_blk: 0, + slp_typa_s5: 0, + slp_typb_s5: 0, + reset_reg: None, + reset_value: 0, // Temporary values aml_symbols: RwLock::new(AmlSymbols::new(ec)), @@ -458,7 +529,10 @@ impl AcpiContext { } Fadt::init(&mut this); - //TODO (hangs on real hardware): Dmar::init(&this); + // DMAR (Intel VT-d) init — previously disabled due to iterator bug (type_bytes copied + // instead of len_bytes in DmarRawIter). Safe to call now: on AMD systems, no DMAR table + // exists and this returns early with a warning. 
+ Dmar::init(&this); this } @@ -562,92 +636,83 @@ impl AcpiContext { aml_symbols.symbol_cache = FxHashMap::default(); } - /// Set Power State - /// See https://uefi.org/sites/default/files/resources/ACPI_6_1.pdf - /// - search for PM1a - /// See https://forum.osdev.org/viewtopic.php?t=16990 for practical details - pub fn set_global_s_state(&self, state: u8) { - if state != 5 { - return; - } - let fadt = match self.fadt() { - Some(fadt) => fadt, - None => { - log::error!("Cannot set global S-state due to missing FADT."); - return; - } - }; - - let port = fadt.pm1a_control_block as u16; - let mut val = 1 << 13; - - let aml_symbols = self.aml_symbols.read(); + pub fn acpi_shutdown(&self) { + let pm1a_value = (u16::from(self.slp_typa_s5) << 10) | 0x2000; + let pm1b_value = (u16::from(self.slp_typb_s5) << 10) | 0x2000; - let s5_aml_name = match acpi::aml::namespace::AmlName::from_str("\\_S5") { - Ok(aml_name) => aml_name, - Err(error) => { - log::error!("Could not build AmlName for \\_S5, {:?}", error); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + let Ok(pm1a_port) = u16::try_from(self.pm1a_cnt_blk) else { + log::error!("PM1a_CNT_BLK address is invalid: {:#X}", self.pm1a_cnt_blk); return; - } - }; + }; - let s5 = match &aml_symbols.aml_context { - Some(aml_context) => match aml_context.namespace.lock().get(s5_aml_name) { - Ok(s5) => s5, - Err(error) => { - log::error!("Cannot set S-state, missing \\_S5, {:?}", error); - return; + log::warn!( + "Shutdown with ACPI PM1a_CNT outw(0x{:X}, 0x{:X})", + pm1a_port, + pm1a_value + ); + Pio::::new(pm1a_port).write(pm1a_value); + + if self.pm1b_cnt_blk != 0 { + match u16::try_from(self.pm1b_cnt_blk) { + Ok(pm1b_port) => { + log::warn!( + "Shutdown with ACPI PM1b_CNT outw(0x{:X}, 0x{:X})", + pm1b_port, + pm1b_value + ); + Pio::::new(pm1b_port).write(pm1b_value); + } + Err(_) => { + log::error!("PM1b_CNT_BLK address is invalid: {:#X}", self.pm1b_cnt_blk); + } } - }, - None => { - log::error!("Cannot set S-state, 
AML context not initialized"); - return; } - }; + } - let package = match s5.deref() { - acpi::aml::object::Object::Package(package) => package, - _ => { - log::error!("Cannot set S-state, \\_S5 is not a package"); - return; - } - }; + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] + { + log::error!( + "Cannot shutdown with ACPI PM1_CNT writes on this architecture (PM1a={:#X}, PM1b={:#X})", + self.pm1a_cnt_blk, + self.pm1b_cnt_blk + ); + } + } - let slp_typa = match package[0].deref() { - acpi::aml::object::Object::Integer(i) => i.to_owned(), - _ => { - log::error!("typa is not an Integer"); - return; + pub fn acpi_reboot(&self) { + match self.reset_reg { + Some(reset_reg) => { + log::warn!( + "Reboot with ACPI reset register {:?} value {:#X}", + reset_reg, + self.reset_value + ); + reset_reg.write_u8(self.reset_value); } - }; - let slp_typb = match package[1].deref() { - acpi::aml::object::Object::Integer(i) => i.to_owned(), - _ => { - log::error!("typb is not an Integer"); - return; + None => { + log::error!("Cannot reboot with ACPI: no reset register present in FADT"); } - }; - - log::trace!("Shutdown SLP_TYPa {:X}, SLP_TYPb {:X}", slp_typa, slp_typb); - val |= slp_typa as u16; - - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - { - log::warn!("Shutdown with ACPI outw(0x{:X}, 0x{:X})", port, val); - Pio::::new(port).write(val); } + } - // TODO: Handle SLP_TYPb + /// Set Power State + /// See https://uefi.org/sites/default/files/resources/ACPI_6_1.pdf + /// - search for PM1a + /// See https://forum.osdev.org/viewtopic.php?t=16990 for practical details + pub fn set_global_s_state(&self, state: u8) { + if state != 5 { + return; + } - #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] - { - log::error!( - "Cannot shutdown with ACPI outw(0x{:X}, 0x{:X}) on this architecture", - port, - val - ); + if self.fadt().is_none() { + log::error!("Cannot set global S-state due to missing FADT."); + return; } + self.acpi_shutdown(); + loop { 
core::hint::spin_loop(); } @@ -707,7 +772,7 @@ unsafe impl plain::Plain for FadtStruct {} #[repr(C, packed)] #[derive(Clone, Copy, Debug, Default)] -pub struct GenericAddressStructure { +pub struct GenericAddress { address_space: u8, bit_width: u8, bit_offset: u8, @@ -715,11 +780,77 @@ pub struct GenericAddressStructure { address: u64, } +impl GenericAddress { + pub fn is_empty(&self) -> bool { + self.address == 0 + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + pub fn write_u8(&self, value: u8) { + match self.address_space { + 0 => { + let raw_address = self.address; + let Ok(address) = usize::try_from(raw_address) else { + log::error!( + "Reset register physical address is invalid: {:#X}", + raw_address + ); + return; + }; + let page = address / PAGE_SIZE * PAGE_SIZE; + let offset = address % PAGE_SIZE; + let virt = unsafe { + common::physmap( + page, + PAGE_SIZE, + common::Prot::RW, + common::MemoryType::default(), + ) + }; + + match virt { + Ok(virt) => unsafe { + (virt as *mut u8).add(offset).write_volatile(value); + let _ = libredox::call::munmap(virt, PAGE_SIZE); + }, + Err(error) => { + log::error!("Failed to map ACPI reset register: {}", error); + } + } + } + 1 => match u16::try_from(self.address) { + Ok(port) => { + Pio::::new(port).write(value); + } + Err(_) => { + let raw_address = self.address; + log::error!("Reset register I/O port is invalid: {:#X}", raw_address); + } + }, + address_space => { + log::warn!( + "Unsupported ACPI reset register address space {} for {:?}", + address_space, + self + ); + } + } + } + + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] + pub fn write_u8(&self, _value: u8) { + log::error!( + "Cannot access ACPI reset register {:?} on this architecture", + self + ); + } +} + #[repr(C, packed)] #[derive(Clone, Copy, Debug)] pub struct FadtAcpi2Struct { // 12 byte structure; see below for details - pub reset_reg: GenericAddressStructure, + pub reset_reg: GenericAddress, pub reset_value: u8, 
reserved3: [u8; 3], @@ -728,14 +859,14 @@ pub struct FadtAcpi2Struct { pub x_firmware_control: u64, pub x_dsdt: u64, - pub x_pm1a_event_block: GenericAddressStructure, - pub x_pm1b_event_block: GenericAddressStructure, - pub x_pm1a_control_block: GenericAddressStructure, - pub x_pm1b_control_block: GenericAddressStructure, - pub x_pm2_control_block: GenericAddressStructure, - pub x_pm_timer_block: GenericAddressStructure, - pub x_gpe0_block: GenericAddressStructure, - pub x_gpe1_block: GenericAddressStructure, + pub x_pm1a_event_block: GenericAddress, + pub x_pm1b_event_block: GenericAddress, + pub x_pm1a_control_block: GenericAddress, + pub x_pm1b_control_block: GenericAddress, + pub x_pm2_control_block: GenericAddress, + pub x_pm_timer_block: GenericAddress, + pub x_gpe0_block: GenericAddress, + pub x_gpe1_block: GenericAddress, } unsafe impl plain::Plain for FadtAcpi2Struct {} @@ -793,9 +924,27 @@ impl Fadt { None => usize::try_from(fadt.dsdt).expect("expected any given u32 to fit within usize"), }; - log::debug!("FACP at {:X}", { dsdt_ptr }); + let pm1a_evt_blk = u64::from(fadt.pm1a_event_block); + let pm1b_evt_blk = u64::from(fadt.pm1b_event_block); + let pm1a_cnt_blk = u64::from(fadt.pm1a_control_block); + let pm1b_cnt_blk = u64::from(fadt.pm1b_control_block); + let (reset_reg, reset_value) = match fadt.acpi_2_struct() { + Some(fadt2) if !fadt2.reset_reg.is_empty() => { + (Some(fadt2.reset_reg), fadt2.reset_value) + } + _ => (None, 0), + }; - let dsdt_sdt = match Sdt::load_from_physical(fadt.dsdt as usize) { + log::debug!("FACP at {:X}", { dsdt_ptr }); + log::debug!( + "FADT power blocks: PM1a_EVT={:#X}, PM1b_EVT={:#X}, PM1a_CNT={:#X}, PM1b_CNT={:#X}", + pm1a_evt_blk, + pm1b_evt_blk, + pm1a_cnt_blk, + pm1b_cnt_blk + ); + + let dsdt_sdt = match Sdt::load_from_physical(dsdt_ptr) { Ok(dsdt) => dsdt, Err(error) => { log::error!("Failed to load DSDT: {}", error); @@ -803,8 +952,48 @@ impl Fadt { } }; + let (slp_typa_s5, slp_typb_s5) = match 
AmlName::from_str("\\_S5") { + Ok(s5_name) => match context.aml_eval(s5_name, Vec::new()) { + Ok(AmlSerdeValue::Package { contents }) => match (contents.get(0), contents.get(1)) + { + ( + Some(AmlSerdeValue::Integer(slp_typa)), + Some(AmlSerdeValue::Integer(slp_typb)), + ) => match (u8::try_from(*slp_typa), u8::try_from(*slp_typb)) { + (Ok(slp_typa_s5), Ok(slp_typb_s5)) => (slp_typa_s5, slp_typb_s5), + _ => { + log::warn!("\\_S5 values do not fit in u8: {:?}", contents); + (0, 0) + } + }, + _ => { + log::warn!("\\_S5 package did not contain two integers: {:?}", contents); + (0, 0) + } + }, + Ok(value) => { + log::warn!("\\_S5 returned unexpected AML value: {:?}", value); + (0, 0) + } + Err(error) => { + log::warn!("Failed to evaluate \\_S5: {:?}", error); + (0, 0) + } + }, + Err(error) => { + log::warn!("Could not build AmlName for \\_S5: {:?}", error); + (0, 0) + } + }; + context.fadt = Some(fadt.clone()); context.dsdt = Some(Dsdt(dsdt_sdt.clone())); + context.pm1a_cnt_blk = pm1a_cnt_blk; + context.pm1b_cnt_blk = pm1b_cnt_blk; + context.slp_typa_s5 = slp_typa_s5; + context.slp_typb_s5 = slp_typb_s5; + context.reset_reg = reset_reg; + context.reset_value = reset_value; context.tables.push(dsdt_sdt); } diff --git a/drivers/acpid/src/acpi/dmar/mod.rs b/drivers/acpid/src/acpi/dmar/mod.rs index c42b379a..f4dff276 100644 --- a/drivers/acpid/src/acpi/dmar/mod.rs +++ b/drivers/acpid/src/acpi/dmar/mod.rs @@ -474,10 +474,13 @@ impl<'sdt> Iterator for DmarRawIter<'sdt> { let len_bytes = <[u8; 2]>::try_from(type_bytes) .expect("expected a 2-byte slice to be convertible to [u8; 2]"); - let ty = u16::from_ne_bytes(type_bytes); - let len = u16::from_ne_bytes(len_bytes); + let len = u16::from_ne_bytes(len_bytes) as usize; - let len = usize::try_from(len).expect("expected u16 to fit within usize"); + if len < 4 { + return None; + } + + let ty = u16::from_ne_bytes(type_bytes); if len > remainder.len() { log::warn!("DMAR remapping structure length was smaller than the remaining 
length of the table."); diff --git a/drivers/input/usbhidd/src/main.rs b/drivers/input/usbhidd/src/main.rs index 15c5b778..68f8689c 100644 --- a/drivers/input/usbhidd/src/main.rs +++ b/drivers/input/usbhidd/src/main.rs @@ -247,7 +247,13 @@ fn main() -> Result<()> { reqs::set_idle(&handle, 1, 0, interface_num as u16).context("Failed to set idle")?; let report_desc_len = hid_desc.desc_len; - assert_eq!(hid_desc.desc_ty, REPORT_DESC_TY); + if hid_desc.desc_ty != REPORT_DESC_TY { + anyhow::bail!( + "unexpected HID descriptor type {:X}, expected {:X}", + hid_desc.desc_ty, + REPORT_DESC_TY + ); + } let mut report_desc_bytes = vec![0u8; report_desc_len as usize]; handle @@ -261,8 +267,8 @@ fn main() -> Result<()> { ) .context("Failed to retrieve report descriptor")?; - let mut handler = - ReportHandler::new(&report_desc_bytes).expect("failed to parse report descriptor"); + let mut handler = ReportHandler::new(&report_desc_bytes) + .map_err(|e| anyhow::anyhow!("failed to parse report descriptor: {}", e))?; let report_len = match endp_desc_opt { Some((_endp_num, endp_desc)) => endp_desc.max_packet_size as usize, @@ -318,10 +324,14 @@ fn main() -> Result<()> { let mut mouse_dy = 0i32; let mut scroll_y = 0i32; let mut buttons = last_buttons; - for event in handler - .handle(&report_buffer) - .expect("failed to parse report") - { + let events = match handler.handle(&report_buffer) { + Ok(events) => events, + Err(err) => { + log::warn!("failed to parse report: {}", err); + continue; + } + }; + for event in events { log::debug!("{}", event); if event.usage_page == UsagePage::GenericDesktop as u16 { if event.usage == GenericDesktopUsage::X as u16 { diff --git a/drivers/pcid/src/scheme.rs b/drivers/pcid/src/scheme.rs index c2caf804..95acdb57 100644 --- a/drivers/pcid/src/scheme.rs +++ b/drivers/pcid/src/scheme.rs @@ -21,6 +21,7 @@ enum Handle { TopLevel { entries: Vec }, Access, Device, + Config { addr: PciAddress }, Channel { addr: PciAddress, st: ChannelState }, SchemeRoot, } @@ 
-30,14 +31,20 @@ struct HandleWrapper { } impl Handle { fn is_file(&self) -> bool { - matches!(self, Self::Access | Self::Channel { .. }) + matches!( + self, + Self::Access | Self::Config { .. } | Self::Channel { .. } + ) } fn is_dir(&self) -> bool { !self.is_file() } // TODO: capability rather than root fn requires_root(&self) -> bool { - matches!(self, Self::Access | Self::Channel { .. }) + matches!( + self, + Self::Access | Self::Config { .. } | Self::Channel { .. } + ) } fn is_scheme_root(&self) -> bool { matches!(self, Self::SchemeRoot) @@ -132,6 +139,7 @@ impl SchemeSync for PciScheme { let (len, mode) = match handle.inner { Handle::TopLevel { ref entries } => (entries.len(), MODE_DIR | 0o755), Handle::Device => (DEVICE_CONTENTS.len(), MODE_DIR | 0o755), + Handle::Config { .. } => (256, MODE_CHR | 0o600), Handle::Access | Handle::Channel { .. } => (0, MODE_CHR | 0o600), Handle::SchemeRoot => return Err(Error::new(EBADF)), }; @@ -156,6 +164,18 @@ impl SchemeSync for PciScheme { match handle.inner { Handle::TopLevel { .. } => Err(Error::new(EISDIR)), Handle::Device => Err(Error::new(EISDIR)), + Handle::Config { addr } => { + let offset = _offset as u16; + let dword_offset = offset & !0x3; + let byte_offset = (offset & 0x3) as usize; + let bytes_to_read = buf.len().min(4 - byte_offset); + + let dword = unsafe { self.pcie.read(addr, dword_offset) }; + let bytes = dword.to_le_bytes(); + buf[..bytes_to_read] + .copy_from_slice(&bytes[byte_offset..byte_offset + bytes_to_read]); + Ok(bytes_to_read) + } Handle::Channel { addr: _, ref mut st, @@ -193,7 +213,9 @@ impl SchemeSync for PciScheme { return Ok(buf); } Handle::Device => DEVICE_CONTENTS, - Handle::Access | Handle::Channel { .. } => return Err(Error::new(ENOTDIR)), + Handle::Access | Handle::Config { .. } | Handle::Channel { .. 
} => { + return Err(Error::new(ENOTDIR)); + } Handle::SchemeRoot => return Err(Error::new(EBADF)), }; @@ -223,6 +245,20 @@ impl SchemeSync for PciScheme { } match handle.inner { + Handle::Config { addr } => { + let offset = _offset as u16; + let dword_offset = offset & !0x3; + let byte_offset = (offset & 0x3) as usize; + let bytes_to_write = buf.len().min(4 - byte_offset); + + let mut dword = unsafe { self.pcie.read(addr, dword_offset) }; + let mut bytes = dword.to_le_bytes(); + bytes[byte_offset..byte_offset + bytes_to_write] + .copy_from_slice(&buf[..bytes_to_write]); + dword = u32::from_le_bytes(bytes); + unsafe { self.pcie.write(addr, dword_offset, dword) }; + Ok(buf.len()) + } Handle::Channel { addr, ref mut st } => { Self::write_channel(&self.pcie, &mut self.tree, addr, st, buf) } @@ -318,6 +354,10 @@ impl PciScheme { func.enabled = false; } } + Some(HandleWrapper { + inner: Handle::Config { .. }, + .. + }) => {} _ => {} } } @@ -343,6 +383,7 @@ impl PciScheme { let path = &after[1..]; match path { + "config" => Handle::Config { addr }, "channel" => { if func.enabled { return Err(Error::new(ENOLCK)); diff --git a/drivers/storage/usbscsid/src/main.rs b/drivers/storage/usbscsid/src/main.rs index 5382d118..803b30fa 100644 --- a/drivers/storage/usbscsid/src/main.rs +++ b/drivers/storage/usbscsid/src/main.rs @@ -17,37 +17,55 @@ fn main() { fn daemon(daemon: daemon::Daemon) -> ! 
{ let mut args = env::args().skip(1); - const USAGE: &'static str = "usbscsid "; + const USAGE: &str = "usbscsid "; - let scheme = args.next().expect(USAGE); - let port = args + let scheme = args.next().unwrap_or_else(|| { + eprintln!("usbscsid: {USAGE}"); + std::process::exit(1); + }); + let port: PortId = args .next() - .expect(USAGE) - .parse::() - .expect("Expected port ID"); - let protocol = args + .unwrap_or_else(|| { + eprintln!("usbscsid: {USAGE}"); + std::process::exit(1); + }) + .parse() + .unwrap_or_else(|e| { + eprintln!("usbscsid: invalid port ID: {e}"); + std::process::exit(1); + }); + let protocol_num: u8 = args .next() - .expect(USAGE) - .parse::() - .expect("protocol has to be a number 0-255"); + .unwrap_or_else(|| { + eprintln!("usbscsid: {USAGE}"); + std::process::exit(1); + }) + .parse() + .unwrap_or_else(|e| { + eprintln!("usbscsid: protocol must be a number 0-255: {e}"); + std::process::exit(1); + }); println!( "USB SCSI driver spawned with scheme `{}`, port {}, protocol {}", - scheme, port, protocol + scheme, port, protocol_num ); let disk_scheme_name = format!("disk.usb-{scheme}+{port}-scsi"); - // TODO: Use eventfds. - let handle = - XhciClientHandle::new(scheme.to_owned(), port).expect("Failed to open XhciClientHandle"); + let handle = XhciClientHandle::new(scheme.to_owned(), port) + .unwrap_or_else(|e| { + eprintln!("usbscsid: failed to open XhciClientHandle: {e}"); + std::process::exit(1); + }); let desc = handle .get_standard_descs() - .expect("Failed to get standard descriptors"); + .unwrap_or_else(|e| { + eprintln!("usbscsid: failed to get standard descriptors: {e}"); + std::process::exit(1); + }); - // TODO: Perhaps the drivers should just be given the config, interface, and alternate setting - // from xhcid. let (conf_desc, configuration_value, (if_desc, interface_num, alternate_setting)) = desc .config_descs .iter() @@ -65,7 +83,10 @@ fn daemon(daemon: daemon::Daemon) -> ! 
{ interface_desc, )) }) - .expect("Failed to find suitable configuration"); + .unwrap_or_else(|| { + eprintln!("usbscsid: failed to find suitable SCSI BOT configuration"); + std::process::exit(1); + }); handle .configure_endpoints(&ConfigureEndpointsReq { @@ -74,20 +95,32 @@ fn daemon(daemon: daemon::Daemon) -> ! { alternate_setting: Some(alternate_setting), hub_ports: None, }) - .expect("Failed to configure endpoints"); - - let mut protocol = protocol::setup(&handle, protocol, &desc, &conf_desc, &if_desc) - .expect("Failed to setup protocol"); - - // TODO: Let all of the USB drivers fork or be managed externally, and xhcid won't have to keep - // track of all the drivers. - let mut scsi = Scsi::new(&mut *protocol).expect("usbscsid: failed to setup SCSI"); + .unwrap_or_else(|e| { + eprintln!("usbscsid: failed to configure endpoints: {e}"); + std::process::exit(1); + }); + + let mut protocol = protocol::setup(&handle, protocol_num, &desc, &conf_desc, &if_desc) + .unwrap_or_else(|| { + eprintln!("usbscsid: failed to setup protocol (protocol 0x{:02x})", protocol_num); + std::process::exit(1); + }); + + let mut scsi = Scsi::new(&mut *protocol).unwrap_or_else(|e| { + eprintln!("usbscsid: failed to setup SCSI: {e}"); + std::process::exit(1); + }); println!("SCSI initialized"); let mut buffer = [0u8; 512]; - scsi.read(&mut *protocol, 0, &mut buffer).unwrap(); - println!("DISK CONTENT: {}", base64::encode(&buffer[..])); + match scsi.read(&mut *protocol, 0, &mut buffer) { + Ok(_) => println!("DISK CONTENT: {}", base64::encode(&buffer[..])), + Err(e) => eprintln!("usbscsid: initial sector read failed: {e}"), + } - let event_queue = event::EventQueue::new().unwrap(); + let event_queue = event::EventQueue::new().unwrap_or_else(|e| { + eprintln!("usbscsid: failed to create event queue: {e}"); + std::process::exit(1); + }); event::user_data! { enum Event { @@ -119,13 +152,25 @@ fn daemon(daemon: daemon::Daemon) -> ! 
{ Event::Scheme, event::EventFlags::READ, ) - .unwrap(); + .unwrap_or_else(|e| { + eprintln!("usbscsid: failed to subscribe to scheme events: {e}"); + std::process::exit(1); + }); for event in event_queue { - match event.unwrap().user_data { - Event::Scheme => driver_block::FuturesExecutor - .block_on(scheme.tick()) - .unwrap(), + match event { + Ok(ev) => match ev.user_data { + Event::Scheme => { + if let Err(e) = driver_block::FuturesExecutor.block_on(scheme.tick()) { + eprintln!("usbscsid: scheme tick error: {e}"); + break; + } + } + }, + Err(e) => { + eprintln!("usbscsid: event queue error: {e}"); + break; + } } } diff --git a/drivers/storage/usbscsid/src/protocol/bot.rs b/drivers/storage/usbscsid/src/protocol/bot.rs index b751d51a..87885653 100644 --- a/drivers/storage/usbscsid/src/protocol/bot.rs +++ b/drivers/storage/usbscsid/src/protocol/bot.rs @@ -88,6 +88,8 @@ pub struct BulkOnlyTransport<'a> { bulk_out: XhciEndpHandle, bulk_in_num: u8, bulk_out_num: u8, + bulk_in_addr: u8, + bulk_out_addr: u8, max_lun: u8, current_tag: u32, interface_num: u8, @@ -98,23 +100,28 @@ pub const FEATURE_ENDPOINT_HALT: u16 = 0; impl<'a> BulkOnlyTransport<'a> { pub fn init( handle: &'a XhciClientHandle, - config_desc: &ConfDesc, + _config_desc: &ConfDesc, if_desc: &IfDesc, ) -> Result { let endpoints = &if_desc.endpoints; - let bulk_in_num = (endpoints + let (bulk_in_idx, bulk_in_desc) = endpoints .iter() - .position(|endpoint| endpoint.direction() == EndpDirection::In) - .unwrap() - + 1) as u8; - let bulk_out_num = (endpoints + .enumerate() + .find(|(_, endpoint)| endpoint.direction() == EndpDirection::In) + .ok_or(ProtocolError::ProtocolError("no bulk IN endpoint found"))?; + let (bulk_out_idx, bulk_out_desc) = endpoints .iter() - .position(|endpoint| endpoint.direction() == EndpDirection::Out) - .unwrap() - + 1) as u8; + .enumerate() + .find(|(_, endpoint)| endpoint.direction() == EndpDirection::Out) + .ok_or(ProtocolError::ProtocolError("no bulk OUT endpoint found"))?; - let 
max_lun = get_max_lun(handle, 0)?; + let bulk_in_num = (bulk_in_idx + 1) as u8; + let bulk_out_num = (bulk_out_idx + 1) as u8; + let bulk_in_addr = bulk_in_desc.address; + let bulk_out_addr = bulk_out_desc.address; + + let max_lun = get_max_lun(handle, if_desc.number.into())?; println!("BOT_MAX_LUN {}", max_lun); Ok(Self { @@ -122,6 +129,8 @@ impl<'a> BulkOnlyTransport<'a> { bulk_out: handle.open_endpoint(bulk_out_num)?, bulk_in_num, bulk_out_num, + bulk_in_addr, + bulk_out_addr, handle, max_lun, current_tag: 0, @@ -133,7 +142,7 @@ impl<'a> BulkOnlyTransport<'a> { self.bulk_in.reset(false)?; self.handle.clear_feature( PortReqRecipient::Endpoint, - u16::from(self.bulk_in_num), + u16::from(self.bulk_in_addr), FEATURE_ENDPOINT_HALT, )?; } @@ -144,7 +153,7 @@ impl<'a> BulkOnlyTransport<'a> { self.bulk_out.reset(false)?; self.handle.clear_feature( PortReqRecipient::Endpoint, - u16::from(self.bulk_out_num), + u16::from(self.bulk_out_addr), FEATURE_ENDPOINT_HALT, )?; } @@ -162,38 +171,59 @@ impl<'a> BulkOnlyTransport<'a> { } Ok(()) } - fn read_csw_raw( - &mut self, - csw_buffer: &mut [u8; 13], - already: bool, - ) -> Result<(), ProtocolError> { - match self.bulk_in.transfer_read(&mut csw_buffer[..])? { - PortTransferStatus { - kind: PortTransferStatusKind::Stalled, - .. - } => { - if already { + fn read_csw(&mut self, csw_buffer: &mut [u8; 13]) -> Result<(), ProtocolError> { + let mut attempts = 0u8; + loop { + let status = self.bulk_in.transfer_read(&mut csw_buffer[..])?; + match status { + PortTransferStatus { + kind: PortTransferStatusKind::Stalled, + .. 
+ } => { + attempts += 1; + if attempts >= 2 { + self.reset_recovery()?; + return Err(ProtocolError::ProtocolError( + "bulk IN stalled repeatedly when reading CSW", + )); + } + eprintln!("usbscsid: bulk IN stalled when reading CSW, clearing stall"); + self.clear_stall_in()?; + continue; + } + PortTransferStatus { + kind: PortTransferStatusKind::ShortPacket, + bytes_transferred, + } if bytes_transferred != 13 => { + eprintln!( + "usbscsid: short packet when reading CSW ({} != 13)", + bytes_transferred + ); self.reset_recovery()?; + return Err(ProtocolError::ProtocolError( + "short packet when reading CSW", + )); + } + PortTransferStatus { + kind: PortTransferStatusKind::Success, + .. + } + | PortTransferStatus { + kind: PortTransferStatusKind::ShortPacket, + bytes_transferred: 13, + } => return Ok(()), + _ => { + eprintln!( + "usbscsid: unexpected transfer status when reading CSW: {:?}", + status + ); + self.reset_recovery()?; + return Err(ProtocolError::ProtocolError( + "unexpected transfer status when reading CSW", + )); } - println!("bulk in endpoint stalled when reading CSW"); - self.clear_stall_in()?; - self.read_csw_raw(csw_buffer, true)?; - } - PortTransferStatus { - kind: PortTransferStatusKind::ShortPacket, - bytes_transferred, - } if bytes_transferred != 13 => { - panic!( - "received a short packet when reading CSW ({} != 13)", - bytes_transferred - ) } - _ => (), } - Ok(()) - } - fn read_csw(&mut self, csw_buffer: &mut [u8; 13]) -> Result<(), ProtocolError> { - self.read_csw_raw(csw_buffer, false) } } @@ -207,7 +237,8 @@ impl<'a> Protocol for BulkOnlyTransport<'a> { let tag = self.current_tag; let mut cbw_bytes = [0u8; 31]; - let cbw = plain::from_mut_bytes::(&mut cbw_bytes).unwrap(); + let cbw = plain::from_mut_bytes::(&mut cbw_bytes) + .map_err(|_| ProtocolError::ProtocolError("CBW buffer size mismatch"))?; *cbw = CommandBlockWrapper::new(tag, data.len() as u32, data.direction().into(), 0, cb)?; let cbw = *cbw; @@ -216,22 +247,48 @@ impl<'a> Protocol 
for BulkOnlyTransport<'a> { kind: PortTransferStatusKind::Stalled, .. } => { - // TODO: Error handling - panic!("bulk out endpoint stalled when sending CBW {:?}", cbw); - //self.clear_stall_out()?; - //dbg!(self.bulk_in.status()?, self.bulk_out.status()?); + eprintln!( + "usbscsid: bulk OUT endpoint stalled when sending CBW {:?}", + cbw + ); + self.clear_stall_out()?; + return Err(ProtocolError::ProtocolError( + "bulk OUT endpoint stalled when sending CBW", + )); } PortTransferStatus { bytes_transferred, .. } if bytes_transferred != 31 => { - panic!( - "received short packet when sending CBW ({} != 31)", + eprintln!( + "usbscsid: short packet when sending CBW ({} != 31)", bytes_transferred ); + self.reset_recovery()?; + return Err(ProtocolError::ProtocolError( + "short packet when sending CBW", + )); + } + PortTransferStatus { + kind: PortTransferStatusKind::Success, + .. + } => (), + PortTransferStatus { + kind: PortTransferStatusKind::ShortPacket, + bytes_transferred: 31, + } => (), + status => { + eprintln!( + "usbscsid: unexpected transfer status {:?} when sending CBW", + status + ); + self.reset_recovery()?; + return Err(ProtocolError::ProtocolError( + "unexpected transfer status when sending CBW", + )); } - _ => (), } + let data_len = data.len() as u32; let early_residue: Option = match data { DeviceReqData::In(buffer) => match self.bulk_in.transfer_read(buffer)? 
{ PortTransferStatus { @@ -240,15 +297,19 @@ impl<'a> Protocol for BulkOnlyTransport<'a> { } => match kind { PortTransferStatusKind::Success => None, PortTransferStatusKind::ShortPacket => { - println!( - "received short packet (len {}) when transferring data", - bytes_transferred + let residue = data_len.saturating_sub(bytes_transferred); + eprintln!( + "usbscsid: short packet ({} of {} bytes) during data read", + bytes_transferred, data_len ); - NonZeroU32::new(bytes_transferred) + NonZeroU32::new(residue) } PortTransferStatusKind::Stalled => { - panic!("bulk in endpoint stalled when reading data"); - //self.clear_stall_in()?; + eprintln!("usbscsid: bulk IN endpoint stalled when reading data"); + self.clear_stall_in()?; + return Err(ProtocolError::ProtocolError( + "bulk IN endpoint stalled during data read", + )); } PortTransferStatusKind::Unknown => { return Err(ProtocolError::XhciError( @@ -266,15 +327,19 @@ impl<'a> Protocol for BulkOnlyTransport<'a> { } => match kind { PortTransferStatusKind::Success => None, PortTransferStatusKind::ShortPacket => { - println!( - "received short packet (len {}) when transferring data", - bytes_transferred + let residue = data_len.saturating_sub(bytes_transferred); + eprintln!( + "usbscsid: short packet ({} of {} bytes) during data write", + bytes_transferred, data_len ); - NonZeroU32::new(bytes_transferred) + NonZeroU32::new(residue) } PortTransferStatusKind::Stalled => { - panic!("bulk out endpoint stalled when reading data"); - //self.clear_stall_out()?; + eprintln!("usbscsid: bulk OUT endpoint stalled when writing data"); + self.clear_stall_out()?; + return Err(ProtocolError::ProtocolError( + "bulk OUT endpoint stalled during data write", + )); } PortTransferStatusKind::Unknown => { return Err(ProtocolError::XhciError( @@ -290,7 +355,8 @@ impl<'a> Protocol for BulkOnlyTransport<'a> { let mut csw_buffer = [0u8; 13]; self.read_csw(&mut csw_buffer)?; - let csw = plain::from_bytes::(&csw_buffer).unwrap(); + let csw = 
plain::from_bytes::(&csw_buffer) + .map_err(|_| ProtocolError::ProtocolError("CSW buffer size mismatch"))?; let residue = early_residue.or(NonZeroU32::new(csw.data_residue)); diff --git a/drivers/storage/usbscsid/src/protocol/mod.rs b/drivers/storage/usbscsid/src/protocol/mod.rs index a580765f..952268c7 100644 --- a/drivers/storage/usbscsid/src/protocol/mod.rs +++ b/drivers/storage/usbscsid/src/protocol/mod.rs @@ -59,22 +59,18 @@ pub trait Protocol { /// Bulk-only transport pub mod bot; -mod uas { - // TODO -} - use bot::BulkOnlyTransport; pub fn setup<'a>( handle: &'a XhciClientHandle, protocol: u8, - dev_desc: &DevDesc, + _dev_desc: &DevDesc, conf_desc: &ConfDesc, if_desc: &IfDesc, ) -> Option> { match protocol { 0x50 => Some(Box::new( - BulkOnlyTransport::init(handle, conf_desc, if_desc).unwrap(), + BulkOnlyTransport::init(handle, conf_desc, if_desc).ok()?, )), _ => None, } diff --git a/drivers/storage/usbscsid/src/scsi/mod.rs b/drivers/storage/usbscsid/src/scsi/mod.rs index 790abea6..fbba4d00 100644 --- a/drivers/storage/usbscsid/src/scsi/mod.rs +++ b/drivers/storage/usbscsid/src/scsi/mod.rs @@ -146,8 +146,15 @@ impl Scsi { &self.command_buffer[..10], DeviceReqData::In(&mut self.data_buffer[..initial_alloc_len as usize]), )? 
{ - self.get_ff_sense(protocol, 252)?; - panic!("{:?}", self.res_ff_sense_data()); + if let Ok(()) = self.get_ff_sense(protocol, 252) { + eprintln!( + "usbscsid: MODE SENSE(10) failed: {:?}", + self.res_ff_sense_data() + ); + } + return Err(ScsiError::ProtocolError(ProtocolError::ProtocolError( + "MODE SENSE(10) command failed", + ))); } let optimal_alloc_len = self.res_mode_param_header10().mode_data_len() + 2; // the length of the mode data field itself @@ -161,7 +168,7 @@ impl Scsi { )?; Ok(( self.res_mode_param_header10(), - self.res_blkdesc_mode10(), + self.res_blkdesc_mode10()?, self.res_mode_pages10(), )) } @@ -199,44 +206,50 @@ impl Scsi { pub fn res_mode_param_header10(&self) -> &cmds::ModeParamHeader10 { plain::from_bytes(&self.data_buffer).unwrap() } - pub fn res_blkdesc_mode6(&self) -> &[cmds::ShortLbaModeParamBlkDesc] { + pub fn res_blkdesc_mode6(&self) -> Result<&[cmds::ShortLbaModeParamBlkDesc]> { let header = self.res_mode_param_header6(); let descs_start = mem::size_of::(); - plain::slice_from_bytes( - &self.data_buffer[descs_start..descs_start + usize::from(header.block_desc_len)], + let desc_len = usize::from(header.block_desc_len); + if descs_start + desc_len > self.data_buffer.len() { + return Err(ScsiError::Overflow( + "block descriptor length exceeds data buffer", + )); + } + Ok( + plain::slice_from_bytes(&self.data_buffer[descs_start..descs_start + desc_len]) + .map_err(|_| ScsiError::Overflow("block descriptor alignment mismatch"))?, ) - .unwrap() } - pub fn res_blkdesc_mode10(&self) -> BlkDescSlice<'_> { + pub fn res_blkdesc_mode10(&self) -> Result> { let header = self.res_mode_param_header10(); let descs_start = mem::size_of::(); + let desc_range = descs_start..descs_start + usize::from(header.block_desc_len()); + if desc_range.end > self.data_buffer.len() { + return Err(ScsiError::Overflow( + "block descriptor length exceeds data buffer", + )); + } if header.longlba() { - BlkDescSlice::Long( - plain::slice_from_bytes( - &self.data_buffer 
- [descs_start..descs_start + usize::from(header.block_desc_len())], - ) - .unwrap(), - ) + Ok(BlkDescSlice::Long( + plain::slice_from_bytes(&self.data_buffer[desc_range]).map_err(|_| { + ScsiError::Overflow("long LBA block descriptor alignment mismatch") + })?, + )) } else if self.res_standard_inquiry_data().periph_dev_ty() != cmds::PeriphDeviceType::DirectAccess as u8 && self.res_standard_inquiry_data().version() == cmds::InquiryVersion::Spc3 as u8 { - BlkDescSlice::General( - plain::slice_from_bytes( - &self.data_buffer - [descs_start..descs_start + usize::from(header.block_desc_len())], - ) - .unwrap(), - ) + Ok(BlkDescSlice::General( + plain::slice_from_bytes(&self.data_buffer[desc_range]).map_err(|_| { + ScsiError::Overflow("general block descriptor alignment mismatch") + })?, + )) } else { - BlkDescSlice::Short( - plain::slice_from_bytes( - &self.data_buffer - [descs_start..descs_start + usize::from(header.block_desc_len())], - ) - .unwrap(), - ) + Ok(BlkDescSlice::Short( + plain::slice_from_bytes(&self.data_buffer[desc_range]).map_err(|_| { + ScsiError::Overflow("short LBA block descriptor alignment mismatch") + })?, + )) } } diff --git a/drivers/usb/usbhubd/src/main.rs b/drivers/usb/usbhubd/src/main.rs index 2c8b9876..eab690dd 100644 --- a/drivers/usb/usbhubd/src/main.rs +++ b/drivers/usb/usbhubd/src/main.rs @@ -2,26 +2,41 @@ use std::{env, thread, time}; use xhcid_interface::{ plain, usb, ConfigureEndpointsReq, DevDesc, DeviceReqData, PortId, PortReqRecipient, PortReqTy, - XhciClientHandle, + UsbSpeed, XhciClientHandle, }; fn main() { common::init(); let mut args = env::args().skip(1); - const USAGE: &'static str = "usbhubd "; + const USAGE: &str = "usbhubd "; - let scheme = args.next().expect(USAGE); + let scheme = args.next().unwrap_or_else(|| { + eprintln!("usbhubd: {USAGE}"); + std::process::exit(1); + }); let port_id = args .next() - .expect(USAGE) + .unwrap_or_else(|| { + eprintln!("usbhubd: {USAGE}"); + std::process::exit(1); + }) .parse::() - 
.expect("Expected port ID"); + .unwrap_or_else(|e| { + eprintln!("usbhubd: invalid port ID: {e}"); + std::process::exit(1); + }); let interface_num = args .next() - .expect(USAGE) + .unwrap_or_else(|| { + eprintln!("usbhubd: {USAGE}"); + std::process::exit(1); + }) .parse::() - .expect("Expected integer as input of interface"); + .unwrap_or_else(|e| { + eprintln!("usbhubd: interface number must be 0-255: {e}"); + std::process::exit(1); + }); log::info!( "USB HUB driver spawned with scheme `{}`, port {}, interface {}", @@ -39,11 +54,14 @@ fn main() { common::file_level(), ); - let handle = - XhciClientHandle::new(scheme.clone(), port_id).expect("Failed to open XhciClientHandle"); - let desc: DevDesc = handle - .get_standard_descs() - .expect("Failed to get standard descriptors"); + let handle = XhciClientHandle::new(scheme.clone(), port_id).unwrap_or_else(|e| { + eprintln!("usbhubd: failed to open XhciClientHandle: {e}"); + std::process::exit(1); + }); + let desc: DevDesc = handle.get_standard_descs().unwrap_or_else(|e| { + eprintln!("usbhubd: failed to get standard descriptors: {e}"); + std::process::exit(1); + }); let (conf_desc, if_desc) = desc .config_descs @@ -58,11 +76,13 @@ fn main() { })?; Some((conf_desc.clone(), if_desc)) }) - .expect("Failed to find suitable configuration"); + .unwrap_or_else(|| { + eprintln!("usbhubd: failed to find configuration with interface {interface_num}"); + std::process::exit(1); + }); // Read hub descriptor let (ports, usb_3) = if desc.major_version() >= 3 { - // USB 3.0 hubs let mut hub_desc = usb::HubDescriptorV3::default(); handle .device_request( @@ -73,10 +93,12 @@ fn main() { 0, DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut hub_desc) }), ) - .expect("Failed to read hub descriptor"); + .unwrap_or_else(|e| { + eprintln!("usbhubd: failed to read USB 3 hub descriptor: {e}"); + std::process::exit(1); + }); (hub_desc.ports, true) } else { - // USB 2.0 and earlier hubs let mut hub_desc = usb::HubDescriptorV2::default(); 
handle .device_request( @@ -87,7 +109,10 @@ fn main() { 0, DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut hub_desc) }), ) - .expect("Failed to read hub descriptor"); + .unwrap_or_else(|e| { + eprintln!("usbhubd: failed to read USB 2 hub descriptor: {e}"); + std::process::exit(1); + }); (hub_desc.ports, false) }; @@ -95,25 +120,55 @@ fn main() { handle .configure_endpoints(&ConfigureEndpointsReq { config_desc: conf_desc.configuration_value, - interface_desc: None, //TODO: stalls on USB 3 hub: Some(interface_num), - alternate_setting: None, //TODO: stalls on USB 3 hub: Some(if_desc.alternate_setting), + interface_desc: Some(interface_num), + alternate_setting: Some(if_desc.alternate_setting), hub_ports: Some(ports), }) - .expect("Failed to configure endpoints after reading hub descriptor"); + .unwrap_or_else(|e| { + eprintln!("usbhubd: failed to configure endpoints: {e}"); + std::process::exit(1); + }); if usb_3 { handle .device_request( PortReqTy::Class, PortReqRecipient::Device, - 0x0c, // SET_HUB_DEPTH + 0x0c, port_id.hub_depth().into(), 0, DeviceReqData::NoData, ) - .expect("Failed to set hub depth"); + .unwrap_or_else(|e| { + eprintln!("usbhubd: failed to set hub depth: {e}"); + std::process::exit(1); + }); } + let status_change_len = (usize::from(ports) + 8) / 8; + let mut status_change_buf = vec![0u8; status_change_len]; + + let mut interrupt_ep = if_desc + .endpoints + .iter() + .find(|ep| { + ep.ty() == xhcid_interface::EndpointTy::Interrupt + && ep.direction() == xhcid_interface::EndpDirection::In + }) + .and_then(|ep| { + let ep_num = ep.address & 0x0F; + match handle.open_endpoint(ep_num) { + Ok(h) => { + log::info!("hub interrupt endpoint {} opened", ep_num); + Some(h) + } + Err(err) => { + log::warn!("failed to open hub interrupt endpoint {}: {}", ep_num, err); + None + } + } + }); + // Initialize states struct PortState { port_id: PortId, @@ -129,111 +184,297 @@ fn main() { } if attached { - self.handle.attach().expect("Failed to attach"); + let 
speed = match &self.port_sts { + usb::HubPortStatus::V2(v2) => UsbSpeed::from_v2_port_status(*v2), + usb::HubPortStatus::V3(v3) => UsbSpeed::from_v3_port_status(*v3), + }; + let res = match speed { + Some(s) => self.handle.attach_with_speed(s), + None => self.handle.attach(), + }; + if let Err(err) = res { + log::error!("failed to attach port {}: {}", self.port_id, err); + return; + } } else { - self.handle.detach().expect("Failed to detach"); + if let Err(err) = self.handle.detach() { + log::error!("failed to detach port {}: {}", self.port_id, err); + return; + } } self.attached = attached; } } - let mut states = Vec::new(); - for port in 1..=ports { - let child_port_id = port_id.child(port).expect("Cannot get child port ID"); - states.push(PortState { - port_id: child_port_id, - port_sts: if usb_3 { - usb::HubPortStatus::V3(usb::HubPortStatusV3::default()) - } else { - usb::HubPortStatus::V2(usb::HubPortStatusV2::default()) - }, - handle: XhciClientHandle::new(scheme.clone(), child_port_id) - .expect("Failed to open XhciClientHandle"), - attached: false, - }); + let mut states: Vec> = (1..=ports) + .map(|port| { + let child_port_id = match port_id.child(port) { + Ok(id) => id, + Err(e) => { + log::error!("port {}: cannot compute child port ID: {}", port, e); + return None; + } + }; + let child_handle = match XhciClientHandle::new(scheme.clone(), child_port_id) { + Ok(h) => h, + Err(e) => { + log::error!("port {}: failed to open XhciClientHandle: {}", port, e); + return None; + } + }; + Some(PortState { + port_id: child_port_id, + port_sts: if usb_3 { + usb::HubPortStatus::V3(usb::HubPortStatusV3::default()) + } else { + usb::HubPortStatus::V2(usb::HubPortStatusV2::default()) + }, + handle: child_handle, + attached: false, + }) + }) + .collect(); + + macro_rules! 
hub_req { + ($handle:expr, $port:expr, $msg:expr, $($arg:expr),*) => { + if let Err(err) = $handle.device_request($($arg),*) { + log::error!("port {}: {} failed: {}", $port, $msg, err); + continue; + } + }; } - //TODO: use change flags? - loop { - for port in 1..=ports { - let port_idx: usize = port.checked_sub(1).unwrap().into(); - let state = states.get_mut(port_idx).unwrap(); + let clear_port_changes = |handle: &XhciClientHandle, port: u8, is_usb3: bool| { + let mut features: Vec = vec![ + usb::HubPortFeature::CPortConnection, + usb::HubPortFeature::CPortReset, + usb::HubPortFeature::CPortOverCurrent, + usb::HubPortFeature::CPortEnable, + ]; + if is_usb3 { + features.push(usb::HubPortFeature::CPortLinkState); + features.push(usb::HubPortFeature::CPortConfigError); + } + for feature in &features { + if let Err(err) = handle.device_request( + PortReqTy::Class, + PortReqRecipient::Other, + usb::SetupReq::ClearFeature as u8, + *feature as u16, + port as u16, + DeviceReqData::NoData, + ) { + log::debug!("port {}: clear feature {:?} failed: {}", port, feature, err); + } + } + }; - let port_sts = if usb_3 { - let mut port_sts = usb::HubPortStatusV3::default(); - handle - .device_request( + let check_all_ports = + |states: &mut Vec>, handle: &XhciClientHandle, usb_3: bool, ports: u8| { + for port in 1..=ports { + let port_idx: usize = (port - 1) as usize; + let state = match states.get_mut(port_idx) { + Some(Some(s)) => s, + _ => continue, + }; + + let port_sts = if usb_3 { + let mut port_sts = usb::HubPortStatusV3::default(); + hub_req!( + handle, + port, + "get status", PortReqTy::Class, PortReqRecipient::Other, usb::SetupReq::GetStatus as u8, 0, port as u16, - DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut port_sts) }), - ) - .expect("Failed to retrieve port status"); - usb::HubPortStatus::V3(port_sts) - } else { - let mut port_sts = usb::HubPortStatusV2::default(); - handle - .device_request( + DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut port_sts) }) + 
); + usb::HubPortStatus::V3(port_sts) + } else { + let mut port_sts = usb::HubPortStatusV2::default(); + hub_req!( + handle, + port, + "get status", PortReqTy::Class, PortReqRecipient::Other, usb::SetupReq::GetStatus as u8, 0, port as u16, - DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut port_sts) }), - ) - .expect("Failed to retrieve port status"); + DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut port_sts) }) + ); + usb::HubPortStatus::V2(port_sts) + }; + if state.port_sts != port_sts { + state.port_sts = port_sts; + log::info!("port {} status {:X?}", port, port_sts); + } + clear_port_changes(handle, port, usb_3); + + if !port_sts.is_powered() { + log::info!("power on port {port}"); + hub_req!( + handle, + port, + "set port power", + PortReqTy::Class, + PortReqRecipient::Other, + usb::SetupReq::SetFeature as u8, + usb::HubPortFeature::PortPower as u16, + port as u16, + DeviceReqData::NoData + ); + state.ensure_attached(false); + continue; + } + + if !port_sts.is_connected() { + state.ensure_attached(false); + continue; + } + + if port_sts.is_resetting() { + state.ensure_attached(false); + continue; + } + + if !port_sts.is_enabled() { + log::info!("reset port {port}"); + hub_req!( + handle, + port, + "set port reset", + PortReqTy::Class, + PortReqRecipient::Other, + usb::SetupReq::SetFeature as u8, + usb::HubPortFeature::PortReset as u16, + port as u16, + DeviceReqData::NoData + ); + state.ensure_attached(false); + continue; + } + + state.ensure_attached(true); + } + }; + + check_all_ports(&mut states, &handle, usb_3, ports); + + loop { + let ports_to_check: Vec = if let Some(ref mut ep) = interrupt_ep { + match ep.transfer_read(&mut status_change_buf) { + Ok(_) => { + let mut changed = Vec::new(); + for port in 1..=ports { + let status_change_bit = usize::from(port); + let byte_idx = status_change_bit / 8; + let bit_idx = status_change_bit % 8; + if byte_idx < status_change_buf.len() + && (status_change_buf[byte_idx] & (1 << bit_idx)) != 0 + { + 
changed.push(port); + } + } + if changed.is_empty() { + (1..=ports).collect() + } else { + changed + } + } + Err(err) => { + log::warn!("hub interrupt read failed: {}, falling back to poll", err); + (1..=ports).collect() + } + } + } else { + (1..=ports).collect() + }; + + for port in ports_to_check { + let port_idx: usize = (port - 1) as usize; + let state = match states.get_mut(port_idx) { + Some(Some(s)) => s, + _ => continue, + }; + + let port_sts = if usb_3 { + let mut port_sts = usb::HubPortStatusV3::default(); + hub_req!( + handle, + port, + "get status", + PortReqTy::Class, + PortReqRecipient::Other, + usb::SetupReq::GetStatus as u8, + 0, + port as u16, + DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut port_sts) }) + ); + usb::HubPortStatus::V3(port_sts) + } else { + let mut port_sts = usb::HubPortStatusV2::default(); + hub_req!( + handle, + port, + "get status", + PortReqTy::Class, + PortReqRecipient::Other, + usb::SetupReq::GetStatus as u8, + 0, + port as u16, + DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut port_sts) }) + ); usb::HubPortStatus::V2(port_sts) }; if state.port_sts != port_sts { state.port_sts = port_sts; log::info!("port {} status {:X?}", port, port_sts); } + clear_port_changes(&handle, port, usb_3); - // Ensure port is powered on if !port_sts.is_powered() { log::info!("power on port {port}"); - handle - .device_request( - PortReqTy::Class, - PortReqRecipient::Other, - usb::SetupReq::SetFeature as u8, - usb::HubPortFeature::PortPower as u16, - port as u16, - DeviceReqData::NoData, - ) - .expect("Failed to set port power"); + hub_req!( + handle, + port, + "set port power", + PortReqTy::Class, + PortReqRecipient::Other, + usb::SetupReq::SetFeature as u8, + usb::HubPortFeature::PortPower as u16, + port as u16, + DeviceReqData::NoData + ); state.ensure_attached(false); continue; } - // Ignore disconnected port if !port_sts.is_connected() { state.ensure_attached(false); continue; } - // Ignore port in reset if port_sts.is_resetting() { 
state.ensure_attached(false); continue; } - // Ensure port is enabled if !port_sts.is_enabled() { log::info!("reset port {port}"); - handle - .device_request( - PortReqTy::Class, - PortReqRecipient::Other, - usb::SetupReq::SetFeature as u8, - usb::HubPortFeature::PortReset as u16, - port as u16, - DeviceReqData::NoData, - ) - .expect("Failed to set port enable"); + hub_req!( + handle, + port, + "set port reset", + PortReqTy::Class, + PortReqRecipient::Other, + usb::SetupReq::SetFeature as u8, + usb::HubPortFeature::PortReset as u16, + port as u16, + DeviceReqData::NoData + ); state.ensure_attached(false); continue; } @@ -241,9 +482,8 @@ fn main() { state.ensure_attached(true); } - //TODO: use interrupts or poll faster? - thread::sleep(time::Duration::new(1, 0)); + if interrupt_ep.is_none() { + thread::sleep(time::Duration::from_millis(100)); + } } - - //TODO: read interrupt port for changes } diff --git a/drivers/usb/xhcid/src/driver_interface.rs b/drivers/usb/xhcid/src/driver_interface.rs index 727f8d7e..bd5b7735 100644 --- a/drivers/usb/xhcid/src/driver_interface.rs +++ b/drivers/usb/xhcid/src/driver_interface.rs @@ -16,6 +16,63 @@ use thiserror::Error; pub use crate::usb::{EndpointTy, ENDP_ATTR_TY_MASK}; +#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[repr(u8)] +pub enum UsbSpeed { + Low = 0, + Full = 1, + High = 2, + Super = 3, + SuperPlus = 4, +} + +impl TryFrom for UsbSpeed { + type Error = (); + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(Self::Low), + 1 => Ok(Self::Full), + 2 => Ok(Self::High), + 3 => Ok(Self::Super), + 4 => Ok(Self::SuperPlus), + _ => Err(()), + } + } +} + +impl UsbSpeed { + pub fn from_v2_port_status(status: crate::usb::HubPortStatusV2) -> Option { + if status.contains(crate::usb::HubPortStatusV2::HIGH_SPEED) { + Some(Self::High) + } else if status.contains(crate::usb::HubPortStatusV2::LOW_SPEED) { + Some(Self::Low) + } else if status.contains(crate::usb::HubPortStatusV2::CONNECTION) { + 
Some(Self::Full) + } else { + None + } + } + + /// Map from USB 3 hub port status speed bits (wPortStatus bits 12:10) to a speed category. + /// + /// Per USB 3.2 spec Table 10-12: 0=undefined, 1=Full, 2=High, 3=SuperSpeed. + /// SuperSpeedPlus requires reading Extended Port Status via a separate device request + /// and is not decoded here; values 4-7 are treated as SuperSpeedPlus as a conservative + /// upper bound. + pub fn from_v3_port_status(status: crate::usb::HubPortStatusV3) -> Option { + let speed_bits = (status.bits() >> 10) & 0x7; + match speed_bits { + 0 => None, + 1 => Some(Self::Full), + 2 => Some(Self::High), + 3 => Some(Self::Super), + 4..=7 => Some(Self::SuperPlus), + _ => None, + } + } +} + #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct ConfigureEndpointsReq { /// Index into the configuration descriptors of the device descriptor. @@ -40,6 +97,8 @@ pub struct DevDesc { pub product_str: Option, pub serial_str: Option, pub config_descs: SmallVec<[ConfDesc; 1]>, + pub supports_superspeed: bool, + pub supports_superspeedplus: bool, } impl DevDesc { @@ -555,6 +614,11 @@ impl XhciClientHandle { let _bytes_written = file.write(&[])?; Ok(()) } + pub fn attach_with_speed(&self, speed: UsbSpeed) -> result::Result<(), XhciClientHandleError> { + let file = self.fd.openat("attach", libredox::flag::O_WRONLY, 0)?; + file.write(&[speed as u8])?; + Ok(()) + } pub fn detach(&self) -> result::Result<(), XhciClientHandleError> { let file = self.fd.openat("detach", libredox::flag::O_WRONLY, 0)?; let _bytes_written = file.write(&[])?; @@ -832,7 +896,7 @@ impl XhciEndpHandle { TransferStream { bytes_to_transfer: total_len, bytes_transferred: 0, - bytes_per_transfer: 32768, // TODO + bytes_per_transfer: 32768, endp_handle: self, } } diff --git a/drivers/usb/xhcid/src/main.rs b/drivers/usb/xhcid/src/main.rs index 25b2fdd6..97354ffe 100644 --- a/drivers/usb/xhcid/src/main.rs +++ b/drivers/usb/xhcid/src/main.rs @@ -140,8 +140,7 @@ fn 
daemon_with_context_size( let address = unsafe { pcid_handle.map_bar(0) }.ptr.as_ptr() as usize; - let (irq_file, interrupt_method) = (None, InterruptMethod::Polling); //get_int_method(&mut pcid_handle); - //TODO: Fix interrupts. + let (irq_file, interrupt_method) = get_int_method(&mut pcid_handle); log::info!("XHCI {}", pci_config.func.display()); diff --git a/drivers/usb/xhcid/src/usb/hub.rs b/drivers/usb/xhcid/src/usb/hub.rs index 9dab55e8..69168bfc 100644 --- a/drivers/usb/xhcid/src/usb/hub.rs +++ b/drivers/usb/xhcid/src/usb/hub.rs @@ -88,8 +88,12 @@ pub enum HubPortFeature { PortLinkState = 5, PortPower = 8, CPortConnection = 16, + CPortEnable = 17, + CPortSuspend = 18, CPortOverCurrent = 19, CPortReset = 20, + CPortLinkState = 25, + CPortConfigError = 26, } bitflags::bitflags! { diff --git a/drivers/usb/xhcid/src/xhci/device_enumerator.rs b/drivers/usb/xhcid/src/xhci/device_enumerator.rs index 74b9f732..32d7f640 100644 --- a/drivers/usb/xhcid/src/xhci/device_enumerator.rs +++ b/drivers/usb/xhcid/src/xhci/device_enumerator.rs @@ -28,7 +28,8 @@ impl DeviceEnumerator { let request = match self.request_queue.recv() { Ok(req) => req, Err(err) => { - panic!("Failed to received an enumeration request! error: {}", err) + log::error!("channel closed, device enumerator exiting: {}", err); + return; } }; @@ -38,7 +39,7 @@ impl DeviceEnumerator { debug!("Device Enumerator request for port {}", port_id); let (len, flags) = { - let ports = self.hci.ports.lock().unwrap(); + let ports = self.hci.ports.lock().unwrap_or_else(|e| e.into_inner()); let len = ports.len(); @@ -68,10 +69,11 @@ impl DeviceEnumerator { && !flags.contains(PortFlags::PR); if !disabled_state { - panic!( - "Port {} isn't in the disabled state! Current flags: {:?}", + warn!( + "Port {} isn't in the disabled state! Current flags: {:?}. 
Continuing.", port_id, flags ); + continue; } else { debug!("Port {} has entered the disabled state.", port_id); } @@ -80,7 +82,7 @@ impl DeviceEnumerator { debug!("Received a device connect on port {}, but it's not enabled. Resetting the port.", port_id); let _ = self.hci.reset_port(port_id); - let mut ports = self.hci.ports.lock().unwrap(); + let mut ports = self.hci.ports.lock().unwrap_or_else(|e| e.into_inner()); let port = &mut ports[port_array_index]; port.clear_prc(); diff --git a/drivers/usb/xhcid/src/xhci/event.rs b/drivers/usb/xhcid/src/xhci/event.rs index 83af1209..4121b0ae 100644 --- a/drivers/usb/xhcid/src/xhci/event.rs +++ b/drivers/usb/xhcid/src/xhci/event.rs @@ -4,6 +4,7 @@ use syscall::error::Result; use common::dma::Dma; use super::ring::Ring; +use super::runtime::RuntimeRegs; use super::trb::Trb; use super::Xhci; @@ -24,9 +25,13 @@ pub struct EventRing { impl EventRing { pub fn new(ac64: bool) -> Result { + Self::new_with_size::(ac64, 256) + } + + pub fn new_with_size(ac64: bool, size: usize) -> Result { let mut ring = EventRing { ste: unsafe { Xhci::::alloc_dma_zeroed_unsized_raw(ac64, 1)? 
}, - ring: Ring::new::(ac64, 256, false)?, + ring: Ring::new::(ac64, size, false)?, }; ring.ste[0] @@ -43,9 +48,14 @@ impl EventRing { pub fn next(&mut self) -> &mut Trb { self.ring.next().0 } - pub fn erdp(&self) -> u64 { + pub fn dequeue_ptr(&self) -> u64 { self.ring.register() & 0xFFFF_FFFF_FFFF_FFF0 } + pub fn erdp(&self, runtime_regs: &RuntimeRegs) -> u64 { + ((u64::from(runtime_regs.ints[0].erdp_high.read()) << 32) + | u64::from(runtime_regs.ints[0].erdp_low.read())) + & 0xFFFF_FFFF_FFFF_FFF0 + } pub fn erstba(&self) -> u64 { self.ste.physical() as u64 } diff --git a/drivers/usb/xhcid/src/xhci/irq_reactor.rs b/drivers/usb/xhcid/src/xhci/irq_reactor.rs index ac492d5b..ed193477 100644 --- a/drivers/usb/xhcid/src/xhci/irq_reactor.rs +++ b/drivers/usb/xhcid/src/xhci/irq_reactor.rs @@ -9,6 +9,7 @@ use std::os::unix::io::AsRawFd; use crossbeam_channel::{Receiver, Sender}; use log::{debug, error, info, trace, warn}; +use syscall::error::{Error, Result, EIO}; use super::doorbell::Doorbell; use super::event::EventRing; @@ -32,7 +33,7 @@ pub struct State { impl State { fn finish(self, message: Option) { - *self.message.lock().unwrap() = message; + *self.message.lock().unwrap_or_else(|e| e.into_inner()) = message; trace!("Waking up future with waker: {:?}", self.waker); self.waker.wake(); } @@ -129,7 +130,7 @@ impl IrqReactor { hci_clone .primary_event_ring .lock() - .unwrap() + .unwrap_or_else(|e| e.into_inner()) .ring .next_index() }; @@ -137,7 +138,10 @@ impl IrqReactor { 'trb_loop: loop { self.pause(); - let mut event_ring = hci_clone.primary_event_ring.lock().unwrap(); + let mut event_ring = hci_clone + .primary_event_ring + .lock() + .unwrap_or_else(|e| e.into_inner()); let event_trb = &mut event_ring.ring.trbs[event_trb_index]; @@ -182,7 +186,7 @@ impl IrqReactor { } fn mask_interrupts(&mut self) { - let mut run = self.hci.run.lock().unwrap(); + let mut run = self.hci.run.lock().unwrap_or_else(|e| e.into_inner()); debug!("Masking interrupts!"); @@ -194,7 +198,7 
@@ impl IrqReactor { } fn unmask_interrupts(&mut self) { - let mut run = self.hci.run.lock().unwrap(); + let mut run = self.hci.run.lock().unwrap_or_else(|e| e.into_inner()); debug!("unmasking interrupts!"); if run.ints[0].iman.readf(1 << 1) { @@ -208,35 +212,72 @@ impl IrqReactor { debug!("Running IRQ reactor with IRQ file and event queue"); let hci_clone = Arc::clone(&self.hci); - let event_queue = - RawEventQueue::new().expect("xhcid irq_reactor: failed to create IRQ event queue"); - let irq_fd = self.irq_file.as_ref().unwrap().as_raw_fd(); - event_queue - .subscribe(irq_fd as usize, 0, event::EventFlags::READ) - .unwrap(); + let event_queue = match RawEventQueue::new() { + Ok(event_queue) => event_queue, + Err(err) => { + error!( + "xhcid irq_reactor: failed to create IRQ event queue: {}", + err + ); + return self.run_polling(); + } + }; + let irq_fd = match self.irq_file.as_ref() { + Some(irq_file) => irq_file.as_raw_fd(), + None => { + error!("xhcid irq_reactor: missing IRQ file, falling back to polling mode"); + return self.run_polling(); + } + }; + if let Err(err) = event_queue.subscribe(irq_fd as usize, 0, event::EventFlags::READ) { + error!( + "xhcid irq_reactor: failed to subscribe IRQ fd {}: {}", + irq_fd, err + ); + return self.run_polling(); + } trace!("IRQ Reactor has created its event queue."); let mut event_trb_index = { hci_clone .primary_event_ring .lock() - .unwrap() + .unwrap_or_else(|e| e.into_inner()) .ring .next_index() }; trace!("IRQ reactor has grabbed the next index in the event ring."); 'trb_loop: loop { - let _event = event_queue.next_event().unwrap(); + let _event = match event_queue.next_event() { + Ok(event) => event, + Err(err) => { + error!("xhcid irq_reactor: failed to read next IRQ event: {}", err); + continue 'trb_loop; + } + }; trace!("IRQ event queue notified"); let mut buffer = [0u8; 8]; - let _ = self - .irq_file - .as_mut() - .unwrap() - .read(&mut buffer) - .expect("Failed to read from irq scheme"); + { + let irq_file = 
match self.irq_file.as_mut() { + Some(irq_file) => irq_file, + None => { + error!( + "xhcid irq_reactor: IRQ file disappeared, falling back to polling mode" + ); + return self.run_polling(); + } + }; + + let _ = match irq_file.read(&mut buffer) { + Ok(n) => n, + Err(err) => { + log::error!("failed to read from irq scheme: {}", err); + continue 'trb_loop; + } + }; + } if !self.hci.received_irq() { // continue only when an IRQ to this device was received @@ -248,11 +289,19 @@ impl IrqReactor { trace!("IRQ reactor received an IRQ"); - let _ = self.irq_file.as_mut().unwrap().write(&buffer); + if let Some(irq_file) = self.irq_file.as_mut() { + let _ = irq_file.write(&buffer); + } else { + error!("xhcid irq_reactor: IRQ file disappeared before IRQ acknowledgement"); + return self.run_polling(); + } // TODO: More event rings, probably even with different IRQs. - let mut event_ring = hci_clone.primary_event_ring.lock().unwrap(); + let mut event_ring = hci_clone + .primary_event_ring + .lock() + .unwrap_or_else(|e| e.into_inner()); let mut count = 0; @@ -321,17 +370,18 @@ impl IrqReactor { route_string: 0, }; trace!("Received Port Status Change Request on port {}", port_id); - self.device_enumerator_sender + if let Err(err) = self + .device_enumerator_sender .send(DeviceEnumerationRequest { port_id }) - .expect( - format!( - "Failed to transmit device numeration request on port {}", - port_id - ) - .as_str(), + { + log::error!( + "port {}: failed to send enumeration request: {}", + port_id, + err ); + } { - let mut ports = self.hci.ports.lock().unwrap(); + let mut ports = self.hci.ports.lock().unwrap_or_else(|e| e.into_inner()); let root_port_index = port_id.root_hub_port_index(); if root_port_index >= ports.len() { warn!( @@ -353,7 +403,7 @@ impl IrqReactor { } fn update_erdp(&self, event_ring: &EventRing) { - let dequeue_pointer_and_dcs = event_ring.erdp(); + let dequeue_pointer_and_dcs = event_ring.dequeue_ptr(); let dequeue_pointer = dequeue_pointer_and_dcs & 
0xFFFF_FFFF_FFFF_FFFE; assert_eq!( dequeue_pointer & 0xFFFF_FFFF_FFFF_FFF0, @@ -363,10 +413,10 @@ impl IrqReactor { trace!("Updated ERDP to {:#0x}", dequeue_pointer); - self.hci.run.lock().unwrap().ints[0] + self.hci.run.lock().unwrap_or_else(|e| e.into_inner()).ints[0] .erdp_low .write(dequeue_pointer as u32); - self.hci.run.lock().unwrap().ints[0] + self.hci.run.lock().unwrap_or_else(|e| e.into_inner()).ints[0] .erdp_high .write((dequeue_pointer >> 32) as u32); } @@ -400,7 +450,7 @@ impl IrqReactor { .hci .cmd .lock() - .unwrap() + .unwrap_or_else(|e| e.into_inner()) .phys_addr_to_entry_mut(self.hci.cap.ac64(), phys_ptr) { Some(command_trb) => { @@ -533,8 +583,84 @@ impl IrqReactor { } /// Grows the event ring fn grow_event_ring(&mut self) { - // TODO - error!("TODO: grow event ring"); + let current_dequeue = { + let event_ring = self + .hci + .primary_event_ring + .lock() + .unwrap_or_else(|e| e.into_inner()); + event_ring.ring.i + }; + let current_size = { + let event_ring = self + .hci + .primary_event_ring + .lock() + .unwrap_or_else(|e| e.into_inner()); + event_ring.ring.trbs.len() + }; + let new_size = current_size * 2; + if new_size > 4096 { + log::error!("event ring growth capped at 4096 entries, skipping growth"); + return; + } + + log::info!( + "growing event ring from {} to {} entries", + current_size, + new_size + ); + + let ac64 = { + let event_ring = self + .hci + .primary_event_ring + .lock() + .unwrap_or_else(|e| e.into_inner()); + let old_ste = &event_ring.ste; + old_ste[0].address_high.read() != 0 + }; + + let mut new_event_ring = match EventRing::new_with_size::(ac64, new_size) { + Ok(ring) => ring, + Err(err) => { + log::error!( + "failed to allocate larger event ring ({} entries): {}", + new_size, + err + ); + return; + } + }; + + new_event_ring.ring.i = current_dequeue.min(new_size - 1); + + let erdp = new_event_ring.ring.register(); + let erstba = new_event_ring.erstba(); + + { + let mut event_ring = self + .hci + .primary_event_ring + 
.lock() + .unwrap_or_else(|e| e.into_inner()); + *event_ring = new_event_ring; + } + + { + let int = &mut self.hci.run.lock().unwrap_or_else(|e| e.into_inner()).ints[0]; + int.erdp_low.write(erdp as u32 | (1 << 3)); + int.erdp_high.write((erdp as u64 >> 32) as u32); + int.erstba_low.write(erstba as u32); + int.erstba_high.write((erstba as u64 >> 32) as u32); + } + + log::info!( + "event ring grown to {} entries, ERDP={:X}, ERSTBA={:X}", + new_size, + erdp, + erstba + ); } pub fn run(self) -> ! { @@ -570,7 +696,7 @@ impl EventDoorbell { pub fn ring(self) { trace!("Ring doorbell {} with data {}", self.index, self.data); - self.dbs.lock().unwrap()[self.index].write(self.data); + self.dbs.lock().unwrap_or_else(|e| e.into_inner())[self.index].write(self.data); trace!("Doorbell was rung."); } } @@ -595,20 +721,26 @@ impl Future for EventTrbFuture { ref state, ref sender, ref mut doorbell_opt, - } => match state.message.lock().unwrap().take() { + } => match state + .message + .lock() + .unwrap_or_else(|e| e.into_inner()) + .take() + { Some(message) => message, None => { // Register state with IRQ reactor trace!("Send state {:X?}", state.state_kind); - sender - .send(State { - message: Arc::clone(&state.message), - is_isoch_or_vf: state.is_isoch_or_vf, - kind: state.state_kind, - waker: context.waker().clone(), - }) - .expect("IRQ reactor thread unexpectedly stopped"); + if let Err(err) = sender.send(State { + message: Arc::clone(&state.message), + is_isoch_or_vf: state.is_isoch_or_vf, + kind: state.state_kind, + waker: context.waker().clone(), + }) { + log::error!("IRQ reactor state channel closed: {}", err); + panic!("IRQ reactor state channel closed: {err}"); + } // Doorbell must be rung after sending state if let Some(doorbell) = doorbell_opt.take() { @@ -667,15 +799,20 @@ impl Xhci { first_trb: &Trb, last_trb: &Trb, doorbell: EventDoorbell, - ) -> impl Future + Send + Sync + 'static { + ) -> Result + Send + Sync + 'static> { if !last_trb.is_transfer_trb() { - 
panic!("Invalid TRB type given to next_transfer_event_trb(): {} (TRB {:?}. Expected transfer TRB.", last_trb.trb_type(), last_trb) + error!( + "Invalid TRB type given to next_transfer_event_trb(): {} (TRB {:?}). Expected transfer TRB.", + last_trb.trb_type(), + last_trb + ); + return Err(Error::new(EIO)); } let is_isoch_or_vf = last_trb.trb_type() == TrbType::Isoch as u8; - let first_phys_ptr = ring.trb_phys_ptr(self.cap.ac64(), first_trb); - let last_phys_ptr = ring.trb_phys_ptr(self.cap.ac64(), last_trb); - EventTrbFuture::Pending { + let first_phys_ptr = ring.trb_phys_ptr(self.cap.ac64(), first_trb)?; + let last_phys_ptr = ring.trb_phys_ptr(self.cap.ac64(), last_trb)?; + Ok(EventTrbFuture::Pending { state: FutureState { is_isoch_or_vf, state_kind: StateKind::Transfer { @@ -687,38 +824,39 @@ impl Xhci { }, sender: self.irq_reactor_sender.clone(), doorbell_opt: Some(doorbell), - } + }) } pub fn next_command_completion_event_trb( &self, command_ring: &Ring, trb: &Trb, doorbell: EventDoorbell, - ) -> impl Future + Send + Sync + 'static { - trace!( - "Sending command at phys_ptr {:X}", - command_ring.trb_phys_ptr(self.cap.ac64(), trb) - ); + ) -> Result + Send + Sync + 'static> { if !trb.is_command_trb() { - panic!("Invalid TRB type given to next_command_completion_event_trb(): {} (TRB {:?}. Expected command TRB.", trb.trb_type(), trb) + error!( + "Invalid TRB type given to next_command_completion_event_trb(): {} (TRB {:?}). Expected command TRB.", + trb.trb_type(), + trb + ); + return Err(Error::new(EIO)); } - EventTrbFuture::Pending { + let phys_ptr = command_ring.trb_phys_ptr(self.cap.ac64(), trb)?; + trace!("Sending command at phys_ptr {:X}", phys_ptr); + Ok(EventTrbFuture::Pending { state: FutureState { // This is only possible for transfers if they are isochronous, or for Force Event TRBs (virtualization). 
is_isoch_or_vf: false, - state_kind: StateKind::CommandCompletion { - phys_ptr: command_ring.trb_phys_ptr(self.cap.ac64(), trb), - }, + state_kind: StateKind::CommandCompletion { phys_ptr }, message: Arc::new(Mutex::new(None)), }, sender: self.irq_reactor_sender.clone(), doorbell_opt: Some(doorbell), - } + }) } pub fn next_misc_event_trb( &self, trb_type: TrbType, - ) -> impl Future + Send + Sync + 'static { + ) -> Result + Send + Sync + 'static> { let valid_trb_types = [ TrbType::PortStatusChange as u8, TrbType::BandwidthRequest as u8, @@ -728,9 +866,13 @@ impl Xhci { TrbType::MfindexWrap as u8, ]; if !valid_trb_types.contains(&(trb_type as u8)) { - panic!("Invalid TRB type given to next_misc_event_trb(): {:?}. Only event TRB types that are neither transfer events or command completion events can be used.", trb_type) + error!( + "Invalid TRB type given to next_misc_event_trb(): {:?}. Only event TRB types that are neither transfer events or command completion events can be used.", + trb_type + ); + return Err(Error::new(EIO)); } - EventTrbFuture::Pending { + Ok(EventTrbFuture::Pending { state: FutureState { is_isoch_or_vf: false, state_kind: StateKind::Other(trb_type), @@ -738,6 +880,6 @@ impl Xhci { }, sender: self.irq_reactor_sender.clone(), doorbell_opt: None, - } + }) } } diff --git a/drivers/usb/xhcid/src/xhci/mod.rs b/drivers/usb/xhcid/src/xhci/mod.rs index f2143676..a51b98c1 100644 --- a/drivers/usb/xhcid/src/xhci/mod.rs +++ b/drivers/usb/xhcid/src/xhci/mod.rs @@ -110,7 +110,7 @@ impl Xhci { .get_mut(&0) .ok_or(Error::new(EIO))? .ring() - .expect("no ring for the default control pipe"); + .ok_or(Error::new(EIO))?; let first_index = ring.next_index(); let (cmd, cycle) = (&mut ring.trbs[first_index], ring.cycle); @@ -140,7 +140,7 @@ impl Xhci { &ring.trbs[first_index], &ring.trbs[last_index], EventDoorbell::new(self, usize::from(slot), Self::def_control_endp_doorbell()), - ) + )? 
}; debug!("Waiting for the next transfer event TRB..."); @@ -485,7 +485,11 @@ impl Xhci { pub fn init(&mut self, max_slots: u8) -> Result<()> { // Set run/stop to 0 debug!("Stopping xHC."); - self.op.get_mut().unwrap().usb_cmd.writef(USB_CMD_RS, false); + self.op + .get_mut() + .unwrap_or_else(|e| e.into_inner()) + .usb_cmd + .writef(USB_CMD_RS, false); // Warm reset { @@ -493,10 +497,16 @@ impl Xhci { let timeout = Timeout::from_secs(1); self.op .get_mut() - .unwrap() + .unwrap_or_else(|e| e.into_inner()) .usb_cmd .writef(USB_CMD_HCRST, true); - while self.op.get_mut().unwrap().usb_cmd.readf(USB_CMD_HCRST) { + while self + .op + .get_mut() + .unwrap_or_else(|e| e.into_inner()) + .usb_cmd + .readf(USB_CMD_HCRST) + { timeout.run().map_err(|()| { log::error!("timeout on USB_CMD_HCRST"); Error::new(EIO) @@ -506,51 +516,76 @@ impl Xhci { // Set enabled slots debug!("Setting enabled slots to {}.", max_slots); - self.op.get_mut().unwrap().config.write(max_slots as u32); + self.op + .get_mut() + .unwrap_or_else(|e| e.into_inner()) + .config + .write(max_slots as u32); debug!( "Enabled Slots: {}", - self.op.get_mut().unwrap().config.read() & 0xFF + self.op + .get_mut() + .unwrap_or_else(|e| e.into_inner()) + .config + .read() + & 0xFF ); // Set device context address array pointer let dcbaap = self.dev_ctx.dcbaap(); debug!("Writing DCBAAP: {:X}", dcbaap); - self.op.get_mut().unwrap().dcbaap_low.write(dcbaap as u32); self.op .get_mut() - .unwrap() + .unwrap_or_else(|e| e.into_inner()) + .dcbaap_low + .write(dcbaap as u32); + self.op + .get_mut() + .unwrap_or_else(|e| e.into_inner()) .dcbaap_high .write((dcbaap as u64 >> 32) as u32); // Set command ring control register - let crcr = self.cmd.get_mut().unwrap().register(); + let crcr = self.cmd.get_mut().unwrap_or_else(|e| e.into_inner()).register(); assert_eq!(crcr & 0xFFFF_FFFF_FFFF_FFC1, crcr, "unaligned CRCR"); debug!("Writing CRCR: {:X}", crcr); - self.op.get_mut().unwrap().crcr_low.write(crcr as u32); self.op .get_mut() 
- .unwrap() + .unwrap_or_else(|e| e.into_inner()) + .crcr_low + .write(crcr as u32); + self.op + .get_mut() + .unwrap_or_else(|e| e.into_inner()) .crcr_high .write((crcr as u64 >> 32) as u32); // Set event ring segment table registers debug!( "Interrupter 0: {:p}", - self.run.get_mut().unwrap().ints.as_ptr() + self.run.get_mut().unwrap_or_else(|e| e.into_inner()).ints.as_ptr() ); { - let int = &mut self.run.get_mut().unwrap().ints[0]; + let int = &mut self.run.get_mut().unwrap_or_else(|e| e.into_inner()).ints[0]; let erstz = 1; debug!("Writing ERSTZ: {}", erstz); int.erstsz.write(erstz); - let erdp = self.primary_event_ring.get_mut().unwrap().erdp(); + let erdp = self + .primary_event_ring + .get_mut() + .unwrap_or_else(|e| e.into_inner()) + .dequeue_ptr(); debug!("Writing ERDP: {:X}", erdp); int.erdp_low.write(erdp as u32 | (1 << 3)); int.erdp_high.write((erdp as u64 >> 32) as u32); - let erstba = self.primary_event_ring.get_mut().unwrap().erstba(); + let erstba = self + .primary_event_ring + .get_mut() + .unwrap_or_else(|e| e.into_inner()) + .erstba(); debug!("Writing ERSTBA: {:X}", erstba); int.erstba_low.write(erstba as u32); int.erstba_high.write((erstba as u64 >> 32) as u32); @@ -563,7 +598,7 @@ impl Xhci { } self.op .get_mut() - .unwrap() + .unwrap_or_else(|e| e.into_inner()) .usb_cmd .writef(USB_CMD_INTE, true); @@ -572,12 +607,22 @@ impl Xhci { // Set run/stop to 1 debug!("Starting xHC."); - self.op.get_mut().unwrap().usb_cmd.writef(USB_CMD_RS, true); + self.op + .get_mut() + .unwrap_or_else(|e| e.into_inner()) + .usb_cmd + .writef(USB_CMD_RS, true); { debug!("Waiting for start request to complete."); let timeout = Timeout::from_secs(1); - while self.op.get_mut().unwrap().usb_sts.readf(USB_STS_HCH) { + while self + .op + .get_mut() + .unwrap_or_else(|e| e.into_inner()) + .usb_sts + .readf(USB_STS_HCH) + { timeout.run().map_err(|()| { log::error!("timeout on USB_STS_HCH"); Error::new(EIO) @@ -587,11 +632,14 @@ impl Xhci { // Ring command doorbell 
debug!("Ringing command doorbell."); - self.dbs.lock().unwrap()[0].write(0); + self.dbs.lock().unwrap_or_else(|e| e.into_inner())[0].write(0); debug!("XHCI initialized."); - self.op.get_mut().unwrap().set_cie(self.cap.cic()); + self.op + .get_mut() + .unwrap_or_else(|e| e.into_inner()) + .set_cie(self.cap.cic()); self.print_port_capabilities(); @@ -599,15 +647,20 @@ impl Xhci { } pub fn get_pls(&self, port_id: PortId) -> u8 { - let mut ports = self.ports.lock().unwrap(); - let port = ports.get_mut(port_id.root_hub_port_index()).unwrap(); - port.state() + let mut ports = self.ports.lock().unwrap_or_else(|e| e.into_inner()); + match ports.get_mut(port_id.root_hub_port_index()) { + Some(port) => port.state(), + None => { + warn!("get_pls: invalid root hub port index {}", port_id.root_hub_port_index()); + 0 + } + } } pub fn poll(&self) { debug!("Polling Initial Devices!"); - let len = self.ports.lock().unwrap().len(); + let len = self.ports.lock().unwrap_or_else(|e| e.into_inner()).len(); for root_hub_port_num in 1..=(len as u8) { let port_id = PortId { @@ -617,7 +670,7 @@ impl Xhci { //Get the CCS and CSC flags let (ccs, csc, flags) = { - let mut ports = self.ports.lock().unwrap(); + let mut ports = self.ports.lock().unwrap_or_else(|e| e.into_inner()); let port = &mut ports[port_id.root_hub_port_index()]; let flags = port.flags(); let ccs = flags.contains(PortFlags::CCS); @@ -633,10 +686,11 @@ impl Xhci { //Do nothing } _ => { - //Either something is connected, or nothing is connected and a port status change was asserted. 
- self.device_enumerator_sender + if let Err(err) = self.device_enumerator_sender .send(DeviceEnumerationRequest { port_id }) - .expect("Failed to generate the port enumeration request!"); + { + log::error!("port {}: failed to send enumeration request: {}", port_id, err); + } } } } @@ -645,7 +699,7 @@ impl Xhci { pub fn print_port_capabilities(&self) { let len; { - let mut ports = self.ports.lock().unwrap(); + let mut ports = self.ports.lock().unwrap_or_else(|e| e.into_inner()); len = ports.len(); } @@ -658,7 +712,7 @@ impl Xhci { let state = self.get_pls(port_id); let mut flags; { - let mut ports = self.ports.lock().unwrap(); + let mut ports = self.ports.lock().unwrap_or_else(|e| e.into_inner()); flags = ports[port_id.root_hub_port_index()].flags(); } @@ -684,9 +738,11 @@ impl Xhci { pub fn reset_port(&self, port_id: PortId) -> Result<()> { debug!("XHCI Port {} reset", port_id); - //TODO handle the second unwrap - let mut ports = self.ports.lock().unwrap(); - let port = ports.get_mut(port_id.root_hub_port_index()).unwrap(); + let mut ports = self.ports.lock().unwrap_or_else(|e| e.into_inner()); + let port = ports.get_mut(port_id.root_hub_port_index()).ok_or_else(|| { + warn!("reset_port: invalid root hub port index {}", port_id.root_hub_port_index()); + Error::new(EIO) + })?; let instant = std::time::Instant::now(); debug!("Port {} Link State: {}", port_id, port.state()); @@ -731,7 +787,7 @@ impl Xhci { { // If ERDP EHB bit is set, clear it before sending command //TODO: find out why this bit is set earlier! 
- let mut run = self.run.lock().unwrap(); + let mut run = self.run.lock().unwrap_or_else(|e| e.into_inner()); let mut int = &mut run.ints[index]; if int.erdp_low.readf(1 << 3) { @@ -743,7 +799,7 @@ impl Xhci { } pub fn interrupt_is_pending(&self, index: usize) -> bool { - let mut run = self.run.lock().unwrap(); + let mut run = self.run.lock().unwrap_or_else(|e| e.into_inner()); let mut int = &mut run.ints[index]; int.erdp_low.readf(1 << 3) } @@ -753,7 +809,7 @@ impl Xhci { let (event_trb, command_trb) = self .execute_command(|cmd, cycle| cmd.enable_slot(slot_ty, cycle)) - .await; + .await?; trace!("Slot is enabled!"); self::scheme::handle_event_trb("ENABLE_SLOT", &event_trb, &command_trb)?; @@ -765,7 +821,7 @@ impl Xhci { trace!("Disable slot {}", slot); let (event_trb, command_trb) = self .execute_command(|cmd, cycle| cmd.disable_slot(slot, cycle)) - .await; + .await?; self::scheme::handle_event_trb("DISABLE_SLOT", &event_trb, &command_trb)?; //self.event_handler_finished(); @@ -793,19 +849,58 @@ impl Xhci { } pub async fn attach_device(&self, port_id: PortId) -> syscall::Result<()> { + self.attach_device_with_speed(port_id, None).await + } + + pub async fn attach_device_with_speed( + &self, + port_id: PortId, + speed_override: Option, + ) -> syscall::Result<()> { if self.port_states.contains_key(&port_id) { debug!("Already contains port {}", port_id); return Err(syscall::Error::new(EAGAIN)); } - let (data, state, speed, flags) = { - let port = &self.ports.lock().unwrap()[port_id.root_hub_port_index()]; + let (data, state, portsc_speed, flags) = { + let port = &self.ports.lock().unwrap_or_else(|e| e.into_inner())[port_id.root_hub_port_index()]; (port.read(), port.state(), port.speed(), port.flags()) }; + let speed = match speed_override { + Some(byte) => { + let category_res = UsbSpeed::try_from(byte).ok(); + match category_res { + Some(category) => match self.lookup_speed_category(port_id, category) { + Some(proto_speed) => { + let psiv = proto_speed.psiv(); + 
log::info!( + "port {} speed override {:?} mapped to PSIV {}", + port_id, + category, + psiv + ); + psiv + } + None => { + log::warn!( + "port {} no protocol speed found for {:?}, falling back to PORTSC speed {}", + port_id, + category, + portsc_speed + ); + portsc_speed + } + }, + None => portsc_speed, + } + } + None => portsc_speed, + }; + debug!( - "XHCI Port {}: {:X}, State {}, Speed {}, Flags {:?}", - port_id, data, state, speed, flags + "XHCI Port {}: {:X}, State {}, Speed {} (override {:?})", + port_id, data, state, speed, speed_override ); if flags.contains(port::PortFlags::CCS) { @@ -829,10 +924,13 @@ impl Xhci { debug!("Enabled port {}, which the xHC mapped to {}", port_id, slot); - //TODO: get correct speed for child devices - let protocol_speed = self - .lookup_psiv(port_id, speed) - .expect("Failed to retrieve speed ID"); + let protocol_speed = match self.lookup_psiv(port_id, speed) { + Some(ps) => ps, + None => { + log::error!("port {}: no protocol speed for PSIV {}, cannot attach", port_id, speed); + return Err(syscall::Error::new(syscall::EINVAL)); + } + }; let mut input = unsafe { self.alloc_dma_zeroed::>()? 
}; @@ -873,9 +971,12 @@ impl Xhci { // Ensure correct packet size is used let dev_desc_8_byte = self.fetch_dev_desc_8_byte(port_id, slot).await?; { - let mut port_state = self.port_states.get_mut(&port_id).unwrap(); + let mut port_state = self.port_states.get_mut(&port_id).ok_or_else(|| { + warn!("fetch_descriptors: missing port state for {}", port_id); + Error::new(EIO) + })?; - let mut input = port_state.input_context.lock().unwrap(); + let mut input = port_state.input_context.lock().unwrap_or_else(|e| e.into_inner()); self.update_max_packet_size(&mut *input, slot, dev_desc_8_byte) .await?; @@ -885,15 +986,24 @@ impl Xhci { let dev_desc = self.get_desc(port_id, slot).await?; debug!("Got the full device descriptor!"); - self.port_states.get_mut(&port_id).unwrap().dev_desc = Some(dev_desc); + self.port_states.get_mut(&port_id).ok_or_else(|| { + warn!("fetch_descriptors: missing port state for {}", port_id); + Error::new(EIO) + })?.dev_desc = Some(dev_desc); debug!("Got the port states again!"); { - let mut port_state = self.port_states.get_mut(&port_id).unwrap(); + let mut port_state = self.port_states.get_mut(&port_id).ok_or_else(|| { + warn!("fetch_descriptors: missing port state for {}", port_id); + Error::new(EIO) + })?; - let mut input = port_state.input_context.lock().unwrap(); + let mut input = port_state.input_context.lock().unwrap_or_else(|e| e.into_inner()); debug!("Got the input context!"); - let dev_desc = port_state.dev_desc.as_ref().unwrap(); + let dev_desc = port_state.dev_desc.as_ref().ok_or_else(|| { + warn!("fetch_descriptors: device descriptor not set for {}", port_id); + Error::new(EIO) + })?; self.update_default_control_pipe(&mut *input, slot, dev_desc) .await?; @@ -932,12 +1042,12 @@ impl Xhci { ); } None => { - //TODO: kill harder warn!( - "driver process {} for port {} still running", + "driver process {} for port {} still running, sending SIGKILL", child.id(), port_id ); + let _ = child.kill(); } }, Err(err) => { @@ -1001,7 +1111,7 @@ impl 
Xhci { .execute_command(|trb, cycle| { trb.evaluate_context(slot_id, input_context.physical(), false, cycle) }) - .await; + .await?; self::scheme::handle_event_trb("EVALUATE_CONTEXT", &event_trb, &command_trb)?; //self.event_handler_finished(); @@ -1035,7 +1145,7 @@ impl Xhci { .execute_command(|trb, cycle| { trb.evaluate_context(slot_id, input_context.physical(), false, cycle) }) - .await; + .await?; debug!("Completed the command to update the default control pipe"); self::scheme::handle_event_trb("EVALUATE_CONTEXT", &event_trb, &command_trb)?; @@ -1060,13 +1170,10 @@ impl Xhci { if let Some((parent_port, port_num)) = port.parent() { match self.port_states.get(&parent_port) { Some(parent_state) => { - // parent info must be supplied if: let mut needs_parent_info = false; + // parent info must be supplied if: // 1. the device is low or full speed and connected through a high speed hub - //TODO: determine device speed (speed is not accurate as it comes from the port) // 2. the device is superspeed and connected through a higher rank hub - //TODO: determine device speed (speed is not accurate as it comes from the port) - // For now, this is just set to true to force things to work needs_parent_info = true; if needs_parent_info { parent_hub_slot_id = parent_state.slot; @@ -1110,8 +1217,7 @@ impl Xhci { | (u32::from(number_of_ports) << 24), ); - // TODO - let ttt = 0u8; + let ttt = 0u8; // TODO: read from parent hub HubDesc think_time field let interrupter = 0u8; assert_eq!(ttt & 0b11, ttt); @@ -1133,7 +1239,7 @@ impl Xhci { 512 }; let host_initiate_disable = false; // only applies to streams - let max_burst_size = 0u8; // TODO + let max_burst_size = 0u8; assert_eq!(max_error_count & 0b11, max_error_count); input_context.device.endpoints[0].b.write( @@ -1166,7 +1272,7 @@ impl Xhci { .execute_command(|trb, cycle| { trb.address_device(slot, input_context_physical, false, cycle) }) - .await; + .await?; if event_trb.completion_code() != TrbCompletionCode::Success as u8 { 
error!( @@ -1190,7 +1296,7 @@ impl Xhci { /// Checks whether an IRQ has been received from *this* device, in case of an interrupt. Always /// true when using MSI/MSI-X. pub fn received_irq(&self) -> bool { - let mut runtime_regs = self.run.lock().unwrap(); + let mut runtime_regs = self.run.lock().unwrap_or_else(|e| e.into_inner()); if self.uses_msi_interrupts() { // Since using MSI and MSI-X implies having no IRQ sharing whatsoever, the IP bit @@ -1224,7 +1330,10 @@ impl Xhci { // TODO: Now that there are some good error crates, I don't think errno.h error codes are // suitable here. - let ps = self.port_states.get(&port).unwrap(); + let ps = self.port_states.get(&port).ok_or_else(|| { + warn!("spawn_drivers: missing port state for {}", port); + Error::new(EIO) + })?; trace!("Spawning driver on port: {}", port); //TODO: support choosing config? @@ -1243,7 +1352,10 @@ impl Xhci { })?; trace!("Got config and device descriptors on port {}", port); - let drivers_usercfg: &DriversConfig = &DRIVERS_CONFIG; + let drivers_usercfg = DRIVERS_CONFIG.as_ref().map_err(|err| { + error!("failed to parse internally embedded xhcid drivers config: {}", err); + Error::new(EIO) + })?; for ifdesc in config_desc.interface_descs.iter() { //TODO: support alternate settings @@ -1435,13 +1547,27 @@ impl Xhci { self.supported_protocol_speeds(port) .find(|speed| speed.psiv() == psiv) } + + fn lookup_speed_category( + &self, + port: PortId, + category: UsbSpeed, + ) -> Option<&'static ProtocolSpeed> { + self.supported_protocol_speeds(port).find(|speed| match category { + UsbSpeed::Low => speed.is_lowspeed(), + UsbSpeed::Full => speed.is_fullspeed(), + UsbSpeed::High => speed.is_highspeed(), + UsbSpeed::Super => speed.is_superspeed_gen1x1(), + UsbSpeed::SuperPlus => speed.is_superspeed_gen_x() && !speed.is_superspeed_gen1x1(), + }) + } } pub fn start_irq_reactor(hci: &Arc>, irq_file: Option) { let hci_clone = Arc::clone(&hci); debug!("About to start IRQ reactor"); - 
*hci.irq_reactor.lock().unwrap() = Some(thread::spawn(move || { + *hci.irq_reactor.lock().unwrap_or_else(|e| e.into_inner()) = Some(thread::spawn(move || { debug!("Started IRQ reactor thread"); IrqReactor::new(hci_clone, irq_file).run() })); @@ -1452,7 +1578,7 @@ pub fn start_device_enumerator(hci: &Arc>) { debug!("About to start Device Enumerator"); - *hci.device_enumerator.lock().unwrap() = Some(thread::spawn(move || { + *hci.device_enumerator.lock().unwrap_or_else(|e| e.into_inner()) = Some(thread::spawn(move || { debug!("Started Device Enumerator"); DeviceEnumerator::new(hci_clone).run(); })); @@ -1480,10 +1606,10 @@ use crate::xhci::port::PortFlags; use lazy_static::lazy_static; lazy_static! { - static ref DRIVERS_CONFIG: DriversConfig = { + static ref DRIVERS_CONFIG: std::result::Result = { // TODO: Load this at runtime. const TOML: &'static [u8] = include_bytes!("../../drivers.toml"); - toml::from_slice::(TOML).expect("Failed to parse internally embedded config file") + toml::from_slice::(TOML) }; } diff --git a/drivers/usb/xhcid/src/xhci/ring.rs b/drivers/usb/xhcid/src/xhci/ring.rs index 8e187ebe..c05ed8fb 100644 --- a/drivers/usb/xhcid/src/xhci/ring.rs +++ b/drivers/usb/xhcid/src/xhci/ring.rs @@ -1,6 +1,7 @@ use std::mem; -use syscall::error::Result; +use log::error; +use syscall::error::{Error, Result, EIO}; use common::dma::Dma; @@ -108,7 +109,7 @@ impl Ring { pub(crate) fn end_virt_addr(&self) -> *const Trb { unsafe { self.start_virt_addr().offset(self.trbs.len() as isize) } } - pub fn trb_phys_ptr(&self, ac64: bool, trb: &Trb) -> u64 { + pub fn trb_phys_ptr(&self, ac64: bool, trb: &Trb) -> Result { let trb_virt_pointer = trb as *const Trb; let trbs_base_virt_pointer = self.trbs.as_ptr(); @@ -116,7 +117,12 @@ impl Ring { || (trb_virt_pointer as usize) > (trbs_base_virt_pointer as usize) + self.trbs.len() * mem::size_of::() { - panic!("Gave a TRB outside of the ring, when retrieving its physical address in that ring. 
TRB: {:?} (at address {:p})", trb, trb); + error!( + "Gave a TRB outside of the ring when retrieving its physical address. TRB: {:?} (at address {:p})", + trb, + trb + ); + return Err(Error::new(EIO)); } let trb_offset_from_base = trb_virt_pointer as u64 - trbs_base_virt_pointer as u64; @@ -127,7 +133,7 @@ impl Ring { 0xFFFF_FFFF }; let trb_phys_ptr = trbs_base_phys_ptr + trb_offset_from_base; - trb_phys_ptr + Ok(trb_phys_ptr) } /* /// Endless mutable iterator that iterates through the ring items, over and over again. The diff --git a/drivers/usb/xhcid/src/xhci/scheme.rs b/drivers/usb/xhcid/src/xhci/scheme.rs index f2d439a4..5868d13d 100644 --- a/drivers/usb/xhcid/src/xhci/scheme.rs +++ b/drivers/usb/xhcid/src/xhci/scheme.rs @@ -54,31 +54,37 @@ use crate::driver_interface::*; use regex::Regex; lazy_static! { - static ref REGEX_PORT_CONFIGURE: Regex = Regex::new(r"^port([\d\.]+)/configure$") - .expect("Failed to create the regex for the port/configure scheme."); - static ref REGEX_PORT_ATTACH: Regex = Regex::new(r"^port([\d\.]+)/attach$") - .expect("Failed to create the regex for the port/attach scheme."); - static ref REGEX_PORT_DETACH: Regex = Regex::new(r"^port([\d\.]+)/detach$") - .expect("Failed to create the regex for the port/detach scheme."); - static ref REGEX_PORT_DESCRIPTORS: Regex = Regex::new(r"^port([\d\.]+)/descriptors$") - .expect("Failed to create the regex for the port/descriptors"); - static ref REGEX_PORT_STATE: Regex = Regex::new(r"^port([\d\.]+)/state$") - .expect("Failed to create the regex for the port/state scheme"); - static ref REGEX_PORT_REQUEST: Regex = Regex::new(r"^port([\d\.]+)/request$") - .expect("Failed to create the regex for the port/request scheme"); - static ref REGEX_PORT_ENDPOINTS: Regex = Regex::new(r"^port([\d\.]+)/endpoints$") - .expect("Failed to create the regex for the port/endpoints scheme"); - static ref REGEX_PORT_SPECIFIC_ENDPOINT: Regex = - Regex::new(r"^port([\d\.]+)/endpoints/(\d{1,3})$") - .expect("Failed to 
create the regex for the port/endpoints/ scheme"); - static ref REGEX_PORT_SUB_ENDPOINT: Regex = Regex::new( - r"port([\d\.]+)/endpoints/(\d{1,3})/(ctl|data)$" - ) - .expect("Failed to create the regex for the port/endpoints// scheme"); - static ref REGEX_PORT_ROOT: Regex = - Regex::new(r"^port([\d\.]+)$").expect("Failed to create the regex for the port scheme."); - static ref REGEX_TOP_LEVEL: Regex = - Regex::new(r"^$").expect("Failed to create the regex for the top-level scheme"); + static ref REGEX_PORT_CONFIGURE: std::result::Result = + Regex::new(r"^port([\d\.]+)/configure$"); + static ref REGEX_PORT_ATTACH: std::result::Result = + Regex::new(r"^port([\d\.]+)/attach$"); + static ref REGEX_PORT_DETACH: std::result::Result = + Regex::new(r"^port([\d\.]+)/detach$"); + static ref REGEX_PORT_DESCRIPTORS: std::result::Result = + Regex::new(r"^port([\d\.]+)/descriptors$"); + static ref REGEX_PORT_STATE: std::result::Result = + Regex::new(r"^port([\d\.]+)/state$"); + static ref REGEX_PORT_REQUEST: std::result::Result = + Regex::new(r"^port([\d\.]+)/request$"); + static ref REGEX_PORT_ENDPOINTS: std::result::Result = + Regex::new(r"^port([\d\.]+)/endpoints$"); + static ref REGEX_PORT_SPECIFIC_ENDPOINT: std::result::Result = + Regex::new(r"^port([\d\.]+)/endpoints/(\d{1,3})$"); + static ref REGEX_PORT_SUB_ENDPOINT: std::result::Result = + Regex::new(r"port([\d\.]+)/endpoints/(\d{1,3})/(ctl|data)$"); + static ref REGEX_PORT_ROOT: std::result::Result = + Regex::new(r"^port([\d\.]+)$"); + static ref REGEX_TOP_LEVEL: std::result::Result = Regex::new(r"^$"); +} + +fn compiled_regex( + regex: &'static std::result::Result, + description: &'static str, +) -> Result<&'static Regex> { + regex.as_ref().map_err(|err| { + error!("failed to compile {} regex: {}", description, err); + Error::new(EIO) + }) } pub enum ControlFlow { @@ -369,51 +375,64 @@ impl SchemeParameters { //and store it if it's valid. //Generate the regular expressions for all of our valid schemes. 
+ let regex_port_configure = compiled_regex(&REGEX_PORT_CONFIGURE, "port/configure")?; + let regex_port_attach = compiled_regex(&REGEX_PORT_ATTACH, "port/attach")?; + let regex_port_detach = compiled_regex(&REGEX_PORT_DETACH, "port/detach")?; + let regex_port_descriptors = compiled_regex(&REGEX_PORT_DESCRIPTORS, "port/descriptors")?; + let regex_port_state = compiled_regex(&REGEX_PORT_STATE, "port/state")?; + let regex_port_request = compiled_regex(&REGEX_PORT_REQUEST, "port/request")?; + let regex_port_endpoints = compiled_regex(&REGEX_PORT_ENDPOINTS, "port/endpoints")?; + let regex_port_specific_endpoint = + compiled_regex(&REGEX_PORT_SPECIFIC_ENDPOINT, "port/endpoints/")?; + let regex_port_sub_endpoint = + compiled_regex(&REGEX_PORT_SUB_ENDPOINT, "port/endpoints//")?; + let regex_port_root = compiled_regex(&REGEX_PORT_ROOT, "port")?; + let regex_top_level = compiled_regex(&REGEX_TOP_LEVEL, "top-level")?; //Check if we have a match and either return a partially initialized scheme, OR ENOENT - if REGEX_PORT_CONFIGURE.is_match(scheme) { - let port_num = get_port_id_from_regex(&REGEX_PORT_CONFIGURE, scheme, 0)?; + if regex_port_configure.is_match(scheme) { + let port_num = get_port_id_from_regex(regex_port_configure, scheme, 0)?; Ok(Self::ConfigureEndpoints(port_num)) - } else if REGEX_PORT_ATTACH.is_match(scheme) { - let port_num = get_port_id_from_regex(&REGEX_PORT_ATTACH, scheme, 0)?; + } else if regex_port_attach.is_match(scheme) { + let port_num = get_port_id_from_regex(regex_port_attach, scheme, 0)?; Ok(Self::AttachDevice(port_num)) - } else if REGEX_PORT_DETACH.is_match(scheme) { - let port_num = get_port_id_from_regex(&REGEX_PORT_DETACH, scheme, 0)?; + } else if regex_port_detach.is_match(scheme) { + let port_num = get_port_id_from_regex(regex_port_detach, scheme, 0)?; Ok(Self::DetachDevice(port_num)) - } else if REGEX_PORT_DESCRIPTORS.is_match(scheme) { - let port_num = get_port_id_from_regex(&REGEX_PORT_DESCRIPTORS, scheme, 0)?; + } else if regex_port_descriptors.is_match(scheme) { + let port_num = 
get_port_id_from_regex(regex_port_descriptors, scheme, 0)?; Ok(Self::PortDesc(port_num)) - } else if REGEX_PORT_STATE.is_match(scheme) { - let port_num = get_port_id_from_regex(&REGEX_PORT_STATE, scheme, 0)?; + } else if regex_port_state.is_match(scheme) { + let port_num = get_port_id_from_regex(regex_port_state, scheme, 0)?; Ok(Self::PortState(port_num)) - } else if REGEX_PORT_REQUEST.is_match(scheme) { - let port_num = get_port_id_from_regex(&REGEX_PORT_REQUEST, scheme, 0)?; + } else if regex_port_request.is_match(scheme) { + let port_num = get_port_id_from_regex(regex_port_request, scheme, 0)?; Ok(Self::PortReq(port_num)) - } else if REGEX_PORT_ENDPOINTS.is_match(scheme) { - let port_num = get_port_id_from_regex(&REGEX_PORT_ENDPOINTS, scheme, 0)?; + } else if regex_port_endpoints.is_match(scheme) { + let port_num = get_port_id_from_regex(regex_port_endpoints, scheme, 0)?; Ok(Self::Endpoints(port_num)) - } else if REGEX_PORT_SPECIFIC_ENDPOINT.is_match(scheme) { - let port_num = get_port_id_from_regex(&REGEX_PORT_SPECIFIC_ENDPOINT, scheme, 0)?; - let endpoint_num = get_u8_from_regex(&REGEX_PORT_SPECIFIC_ENDPOINT, scheme, 1)?; + } else if regex_port_specific_endpoint.is_match(scheme) { + let port_num = get_port_id_from_regex(regex_port_specific_endpoint, scheme, 0)?; + let endpoint_num = get_u8_from_regex(regex_port_specific_endpoint, scheme, 1)?; Ok(Self::Endpoint(port_num, endpoint_num, String::from("root"))) - } else if REGEX_PORT_SUB_ENDPOINT.is_match(scheme) { - let port_num = get_port_id_from_regex(&REGEX_PORT_SUB_ENDPOINT, scheme, 0)?; - let endpoint_num = get_u8_from_regex(&REGEX_PORT_SUB_ENDPOINT, scheme, 1)?; - let handle_type = get_string_from_regex(&REGEX_PORT_SUB_ENDPOINT, scheme, 2)?; + } else if regex_port_sub_endpoint.is_match(scheme) { + let port_num = get_port_id_from_regex(regex_port_sub_endpoint, scheme, 0)?; + let endpoint_num = get_u8_from_regex(regex_port_sub_endpoint, scheme, 1)?; + let handle_type = get_string_from_regex(regex_port_sub_endpoint, scheme, 2)?; 
Ok(Self::Endpoint(port_num, endpoint_num, handle_type)) - } else if REGEX_PORT_ROOT.is_match(scheme) { - let port_num = get_port_id_from_regex(&REGEX_PORT_ROOT, scheme, 0)?; + } else if regex_port_root.is_match(scheme) { + let port_num = get_port_id_from_regex(regex_port_root, scheme, 0)?; Ok(Self::Port(port_num)) - } else if REGEX_TOP_LEVEL.is_match(scheme) { + } else if regex_top_level.is_match(scheme) { Ok(Self::TopLevel) } else { Err(Error::new(ENOENT)) @@ -589,7 +608,7 @@ impl Xhci { /// /// # Locking /// This function will lock `Xhci::cmd` and `Xhci::dbs`. - pub async fn execute_command(&self, f: F) -> (Trb, Trb) { + pub async fn execute_command(&self, f: F) -> Result<(Trb, Trb)> { //TODO: find out why this bit is set earlier! if self.interrupt_is_pending(0) { debug!("The EHB bit is already set!"); @@ -597,7 +616,7 @@ impl Xhci { } let next_event = { - let mut command_ring = self.cmd.lock().unwrap(); + let mut command_ring = self.cmd.lock().unwrap_or_else(|e| e.into_inner()); let (cmd_index, cycle) = (command_ring.next_index(), command_ring.cycle); debug!("Sending command with cycle bit {}", cycle as u8); @@ -613,12 +632,15 @@ impl Xhci { &*command_ring, command_trb, EventDoorbell::new(self, 0, 0), - ) + )? }; let trbs = next_event.await; let event_trb = trbs.event_trb; - let command_trb = trbs.src_trb.expect("Command completion event TRBs shall always have a valid pointer to a valid source command TRB"); + let command_trb = trbs.src_trb.ok_or_else(|| { + error!("command completion event TRB missing source command TRB"); + Error::new(EIO) + })?; assert_eq!( event_trb.trb_type(), @@ -626,7 +648,7 @@ impl Xhci { "The IRQ reactor (or the xHC) gave an invalid event TRB" ); - (event_trb, command_trb) + Ok((event_trb, command_trb)) } pub async fn execute_control_transfer( &self, @@ -681,7 +703,7 @@ impl Xhci { &ring.trbs[first_index], &ring.trbs[last_index], EventDoorbell::new(self, usize::from(slot), Self::def_control_endp_doorbell()), - ) + )? 
}; let trbs = future.await; @@ -773,7 +795,7 @@ impl Xhci { doorbell_data_no_stream }, ), - ); + )?; } ControlFlow::Continue => continue, } @@ -856,7 +878,7 @@ impl Xhci { .execute_command(|trb, cycle| { trb.reset_endpoint(slot, endp_num_xhc, tsp, cycle); }) - .await; + .await?; //self.event_handler_finished(); handle_event_trb("RESET_ENDPOINT", &event_trb, &command_trb) @@ -893,7 +915,9 @@ impl Xhci { endp_desc: &EndpDesc, ) -> u8 { if speed_id.is_highspeed() && (endp_desc.is_interrupt() || endp_desc.is_isoch()) { - assert_eq!(dev_desc.major_version(), 2); + if dev_desc.major_version() != 2 { + log::warn!("high-speed endpoint on USB {} device, expected USB 2", dev_desc.major_version()); + } ((endp_desc.max_packet_size & 0x0C00) >> 11) as u8 } else if endp_desc.is_superspeed() { endp_desc.max_burst() @@ -916,10 +940,10 @@ impl Xhci { if dev_desc.major_version() == 2 && endp_desc.is_periodic() { u32::from(max_packet_size) * (u32::from(max_burst_size) + 1) - } else if endp_desc.has_ssp_companion() { - endp_desc.sspc.as_ref().unwrap().bytes_per_interval - } else if endp_desc.ssc.is_some() { - u32::from(endp_desc.ssc.as_ref().unwrap().bytes_per_interval) + } else if let Some(sspc) = endp_desc.sspc.as_ref() { + sspc.bytes_per_interval + } else if let Some(ssc) = endp_desc.ssc.as_ref() { + u32::from(ssc.bytes_per_interval) } else if speed_id.is_fullspeed() && endp_desc.is_interrupt() { 64 } else if speed_id.is_fullspeed() && endp_desc.is_isoch() { @@ -957,16 +981,17 @@ impl Xhci { - port_state.cfg_idx = Some(req.config_desc); - - let config_desc = port_state + let dev_desc = port_state .dev_desc .as_ref() - .unwrap() + .ok_or(Error::new(EBADFD))?; + + let config_desc = dev_desc .config_descs .iter() .find(|desc| desc.configuration_value == req.config_desc) .ok_or(Error::new(EBADFD))?; + let configuration_value = config_desc.configuration_value; @@ -987,16 +1012,18 @@ impl Xhci { return Err(Error::new(EIO)); } + port_state.cfg_idx = Some(configuration_value); + ( 
endp_desc_count, new_context_entries, - config_desc.configuration_value, + configuration_value, ) }; let lec = self.cap.lec(); let log_max_psa_size = self.cap.max_psa_size(); @@ -996,7 +1022,7 @@ impl Xhci { let lec = self.cap.lec(); let log_max_psa_size = self.cap.max_psa_size(); - let port_speed_id = self.ports.lock().unwrap()[port.root_hub_port_index()].speed(); + let port_speed_id = self.ports.lock().unwrap_or_else(|e| e.into_inner())[port.root_hub_port_index()].speed(); let speed_id: &ProtocolSpeed = self.lookup_psiv(port, port_speed_id).ok_or_else(|| { warn!("no speed_id"); Error::new(EIO) @@ -1004,7 +1030,7 @@ impl Xhci { { let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; - let mut input_context = port_state.input_context.lock().unwrap(); + let mut input_context = port_state.input_context.lock().unwrap_or_else(|e| e.into_inner()); // Configure the slot context as well, which holds the last index of the endp descs. input_context.add_context.write(1); @@ -1035,7 +1061,7 @@ impl Xhci { input_context.device.slot.a.write(current_slot_a); input_context.device.slot.b.write(current_slot_b); - let control = if self.op.lock().unwrap().cie() { + let control = if self.op.lock().unwrap_or_else(|e| e.into_inner()).cie() { (u32::from(req.alternate_setting.unwrap_or(0)) << 16) | (u32::from(req.interface_desc.unwrap_or(0)) << 8) | u32::from(configuration_value) @@ -1049,7 +1075,7 @@ impl Xhci { let endp_num = endp_idx + 1; let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; - let dev_desc = port_state.dev_desc.as_ref().unwrap(); + let dev_desc = port_state.dev_desc.as_ref().ok_or(Error::new(EBADFD))?; let endp_desc = port_state.get_endp_desc(endp_idx).ok_or_else(|| { warn!("failed to find endpoint {}", endp_idx); Error::new(EIO) @@ -1111,7 +1137,10 @@ impl Xhci { assert_eq!(ep_ty & 0x7, ep_ty); assert_eq!(mult & 0x3, mult); assert_eq!(max_error_count & 0x3, max_error_count); - assert_ne!(ep_ty, 0); // 0 means invalid. 
+ if ep_ty == 0 { + warn!("endpoint {} has invalid xHCI type 0", endp_num); + return Err(Error::new(EIO)); + } let ring_ptr = if usb_log_max_streams.is_some() { let mut array = @@ -1155,7 +1184,7 @@ impl Xhci { }; assert_eq!(primary_streams & 0x1F, primary_streams); - let mut input_context = port_state.input_context.lock().unwrap(); + let mut input_context = port_state.input_context.lock().unwrap_or_else(|e| e.into_inner()); input_context.add_context.writef(1 << endp_num_xhc, true); let endp_i = endp_num_xhc as usize - 1; @@ -1191,13 +1220,13 @@ impl Xhci { { let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; let slot = port_state.slot; - let input_context_physical = port_state.input_context.lock().unwrap().physical(); + let input_context_physical = port_state.input_context.lock().unwrap_or_else(|e| e.into_inner()).physical(); - let (event_trb, command_trb) = self - .execute_command(|trb, cycle| { - trb.configure_endpoint(slot, input_context_physical, cycle) - }) - .await; + let (event_trb, command_trb) = self + .execute_command(|trb, cycle| { + trb.configure_endpoint(slot, input_context_physical, cycle) + }) + .await?; //self.event_handler_finished(); @@ -1234,8 +1263,16 @@ impl Xhci { if let Some(interface_num) = req.interface_desc { if let Some(alternate_setting) = req.alternate_setting { - self.set_interface(port, interface_num, alternate_setting) - .await?; + if let Err(err) = self.set_interface(port, interface_num, alternate_setting).await { + if alternate_setting == 0 && interface_num == 0 { + log::debug!( + "port {}: SET_INTERFACE(0,0) failed (stall likely): {}, continuing", + port, err + ); + } else { + return Err(err); + } + } } } @@ -1261,7 +1298,8 @@ impl Xhci { ) .await?; - buf.copy_from_slice(&*dma_buffer.as_ref().unwrap()); + let dma_ref = dma_buffer.as_ref().ok_or(Error::new(EIO))?; + buf.copy_from_slice(&*dma_ref); Ok((completion_code, bytes_transferred)) } async fn transfer_write( @@ -1327,7 +1365,6 @@ impl Xhci { dma_buf: 
Option>, direction: PortReqDirection, ) -> Result<(u8, u32, Option>)> { - // TODO: Check that only readable enpoints are read, etc. let endp_num = endp_idx + 1; let mut port_state = self @@ -1442,7 +1479,7 @@ impl Xhci { Ok((event.completion_code(), bytes_transferred, dma_buf)) } pub async fn get_desc(&self, port_id: PortId, slot: u8) -> Result { - let ports = self.ports.lock().unwrap(); + let ports = self.ports.lock().unwrap_or_else(|e| e.into_inner()); let port = ports .get(port_id.root_hub_port_index()) .ok_or(Error::new(ENOENT))?; @@ -1506,12 +1543,45 @@ impl Xhci { serial_str ); - //TODO let (bos_desc, bos_data) = self.fetch_bos_desc(port_id, slot).await?; - - let supports_superspeed = false; - //TODO usb::bos_capability_descs(bos_desc, &bos_data).any(|desc| desc.is_superspeed()); - let supports_superspeedplus = false; - //TODO usb::bos_capability_descs(bos_desc, &bos_data).any(|desc| desc.is_superspeedplus()); + let (supports_superspeed, supports_superspeedplus) = + match self.fetch_bos_desc(port_id, slot).await { + Ok((bos_desc, bos_data)) => { + let bos_len = bos_desc.total_len as usize; + let bos_slice = if bos_len <= bos_data.len() { + &bos_data[..bos_len] + } else { + log::warn!( + "port {} slot {} BOS total_len {} exceeds buffer {}, truncating", + port_id, slot, bos_len, bos_data.len() + ); + &bos_data[..] 
+ }; + let caps: Vec<_> = usb::bos_capability_descs( + bos_desc, + bos_slice, + ) + .collect(); + let ss = caps.iter().any(|desc| desc.is_superspeed()); + let ssp = caps.iter().any(|desc| desc.is_superspeedplus()); + log::info!( + "port {} slot {} BOS: superspeed={} superspeedplus={}", + port_id, + slot, + ss, + ssp + ); + (ss, ssp) + } + Err(err) => { + log::debug!( + "port {} slot {} BOS descriptor not available: {}", + port_id, + slot, + err + ); + (false, false) + } + }; let mut config_descs = SmallVec::new(); @@ -1564,11 +1634,11 @@ impl Xhci { match iter.peek() { Some(AnyDescriptor::SuperSpeedCompanion(n)) => { endp.ssc = Some(SuperSpeedCmp::from(n.clone())); - iter.next().unwrap(); + let _ = iter.next(); } Some(AnyDescriptor::SuperSpeedPlusCompanion(n)) => { endp.sspc = Some(SuperSpeedPlusIsochCmp::from(n.clone())); - iter.next().unwrap(); + let _ = iter.next(); } _ => break, } @@ -1619,6 +1689,8 @@ impl Xhci { product_str, serial_str, config_descs, + supports_superspeed, + supports_superspeedplus, }) } fn port_desc_json(&self, port_id: PortId) -> Result> { @@ -1801,14 +1873,14 @@ impl Xhci { if flags & O_DIRECTORY != 0 || flags & O_STAT != 0 { let mut contents = Vec::new(); - let ports_guard = self.ports.lock().unwrap(); + let ports_guard = self.ports.lock().unwrap_or_else(|e| e.into_inner()); for (index, _) in ports_guard .iter() .enumerate() .filter(|(_, port)| port.flags().contains(port::PortFlags::CCS)) { - write!(contents, "port{}\n", index).unwrap(); + write!(contents, "port{}\n", index).map_err(|_| Error::new(EIO))?; } Ok(Handle::TopLevel(contents)) @@ -1856,7 +1928,7 @@ impl Xhci { if (flags & O_DIRECTORY != 0) || (flags & O_STAT != 0) { let mut contents = Vec::new(); - write!(contents, "descriptors\nendpoints\n").unwrap(); + write!(contents, "descriptors\nendpoints\n").map_err(|_| Error::new(EIO))?; if self.slot_state( self.port_states @@ -1865,7 +1937,7 @@ impl Xhci { .slot as usize, ) != SlotState::Configured as u8 { - write!(contents, 
"configure\n").unwrap(); + write!(contents, "configure\n").map_err(|_| Error::new(EIO))?; } Ok(Handle::Port(port_num, contents)) @@ -1916,7 +1988,7 @@ impl Xhci { }*/ for ep_num in ps.endpoint_states.keys() { - write!(contents, "{}\n", ep_num).unwrap(); + write!(contents, "{}\n", ep_num).map_err(|_| Error::new(EIO))?; } Ok(Handle::Endpoints(port_num, contents)) @@ -2007,10 +2079,13 @@ impl Xhci { }; Ok(Handle::Endpoint(port_num, endpoint_num, st)) } - _ => panic!( - "Scheme parser returned an invalid string '{}' for the endpoint handle type", - handle_type - ), + _ => { + log::error!( + "Scheme parser returned an invalid string '{}' for the endpoint handle type", + handle_type + ); + return Err(Error::new(ENOENT)); + } } } @@ -2218,16 +2293,9 @@ impl SchemeSync for &Xhci { let guard = self.handles.get(&fd).ok_or(Error::new(EBADF))?; let scheme = (&*guard).to_scheme(); - write!(cursor, "{}", scheme.as_str()).expect( - format!( - "Failed to convert the file descriptor with value {} to the associated file path", - fd - ) - .as_str(), - ); + write!(cursor, "{}", scheme.as_str()).map_err(|_| Error::new(EIO))?; - let src_len = usize::try_from(cursor.seek(io::SeekFrom::End(0)).unwrap()).unwrap(); - Ok(src_len) + Ok(cursor.position() as usize) } fn read( @@ -2324,12 +2392,15 @@ impl SchemeSync for &Xhci { Ok(buf.len()) } &mut Handle::AttachDevice(port_num) => { - //TODO: accept some arguments in buffer? - block_on(self.attach_device(port_num))?; + let speed_override = if buf.len() == 1 { + Some(buf[0]) + } else { + None + }; + block_on(self.attach_device_with_speed(port_num, speed_override))?; Ok(buf.len()) } &mut Handle::DetachDevice(port_num) => { - //TODO: accept some arguments in buffer? block_on(self.detach_device(port_num))?; Ok(buf.len()) } @@ -2364,7 +2435,7 @@ impl Xhci { let endp_desc = port_state .dev_desc .as_ref() - .unwrap() + .ok_or(Error::new(EBADFD))? .config_descs .get(0) .ok_or(Error::new(EIO))? 
@@ -2409,8 +2480,12 @@ impl Xhci { if self.get_endp_status(port_num, endp_num)? != EndpointStatus::Halted { return Err(Error::new(EPROTO)); } - // Change the endpoint state from anything, but most likely HALTED (otherwise resetting - // would be quite meaningless), to stopped. + + let endp_idx = endp_num.checked_sub(1).ok_or(Error::new(EIO))?; + let port_state = self.port_states.get(&port_num).ok_or(Error::new(EBADFD))?; + let endp_desc = port_state.get_endp_desc(endp_idx).ok_or(Error::new(EBADF))?; + let usb_endp_addr = endp_desc.address; + self.reset_endpoint(port_num, endp_num, false).await?; self.restart_endpoint(port_num, endp_num).await?; @@ -2418,10 +2493,10 @@ impl Xhci { self.device_req_no_data( port_num, usb::Setup { - kind: 0b0000_0010, // endpoint recipient - request: 0x01, // CLEAR_FEATURE - value: 0x00, // ENDPOINT_HALT - index: 0, // TODO: interface num + kind: 0b0000_0010, + request: 0x01, + value: 0x00, + index: u16::from(usb_endp_addr), length: 0, }, ) @@ -2456,7 +2531,7 @@ impl Xhci { let endp_desc = port_state .dev_desc .as_ref() - .unwrap() + .ok_or(Error::new(EBADFD))? .config_descs .get(0) .ok_or(Error::new(EIO))? @@ -2475,7 +2550,7 @@ impl Xhci { Self::def_control_endp_doorbell() }; - self.dbs.lock().unwrap()[slot as usize].write(doorbell); + self.dbs.lock().unwrap_or_else(|e| e.into_inner())[slot as usize].write(doorbell); self.set_tr_deque_ptr(port_num, endp_num, deque_ptr_and_cycle) .await?; @@ -2483,13 +2558,14 @@ impl Xhci { Ok(()) } pub fn endp_direction(&self, port_num: PortId, endp_num: u8) -> Result { + let endp_idx = endp_num.checked_sub(1).ok_or(Error::new(EIO))? as usize; Ok(self .port_states .get(&port_num) .ok_or(Error::new(EIO))? .dev_desc .as_ref() - .unwrap() + .ok_or(Error::new(EBADFD))? .config_descs .first() .ok_or(Error::new(EIO))? .interface_descs .first() .ok_or(Error::new(EIO))? .endpoints - .get(endp_num as usize) + .get(endp_idx) .ok_or(Error::new(EIO))? 
.direction()) } @@ -2530,7 +2605,7 @@ impl Xhci { slot, ) }) - .await; + .await?; //self.event_handler_finished(); handle_event_trb("SET_TR_DEQUEUE_PTR", &event_trb, &command_trb) @@ -2724,7 +2799,7 @@ impl Xhci { let mut cursor = io::Cursor::new(buf); serde_json::to_writer(&mut cursor, &res).or(Err(Error::new(EIO)))?; - Ok(cursor.seek(io::SeekFrom::Current(0)).unwrap() as usize) + Ok(cursor.position() as usize) } pub async fn on_read_endp_data( &self, @@ -2803,7 +2878,7 @@ impl Xhci { pub fn event_handler_finished(&self) { trace!("Event handler finished"); // write 1 to EHB to clear it - self.run.lock().unwrap().ints[0] + self.run.lock().unwrap_or_else(|e| e.into_inner()).ints[0] .erdp_low .writef(1 << 3, true); }