diff --git a/Cargo.lock b/Cargo.lock index 3986e775..87c1a277 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21,7 +21,6 @@ dependencies = [ [[package]] name = "acpi" version = "6.1.1" -source = "git+https://github.com/jackpot51/acpi.git#3dc8a2d98a7a164cbf87e7a86855c4d3bed4de75" dependencies = [ "bit_field", "bitflags 2.11.0", @@ -54,6 +53,7 @@ dependencies = [ "scheme-utils", "serde", "thiserror 2.0.18", + "toml 1.0.6+spec-1.1.0", ] [[package]] @@ -86,7 +86,7 @@ version = "0.0.1" dependencies = [ "acpi", "serde", - "toml", + "toml 1.0.6+spec-1.1.0", ] [[package]] @@ -1109,7 +1109,7 @@ dependencies = [ "redox_syscall 0.7.4", "serde", "serde_json", - "toml", + "toml 1.0.6+spec-1.1.0", ] [[package]] @@ -1505,9 +1505,10 @@ dependencies = [ "log", "pcid", "pico-args", + "redox-driver-sys", "redox_syscall 0.7.4", "serde", - "toml", + "toml 1.0.6+spec-1.1.0", ] [[package]] @@ -1779,6 +1780,20 @@ dependencies = [ "rand_core 0.3.1", ] +[[package]] +name = "redox-driver-sys" +version = "0.1.0" +dependencies = [ + "bincode", + "bitflags 2.11.0", + "libredox", + "log", + "redox_syscall 0.7.4", + "serde", + "thiserror 2.0.18", + "toml 0.8.23", +] + [[package]] name = "redox-initfs" version = "0.2.0" @@ -2130,6 +2145,15 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_spanned" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +dependencies = [ + "serde", +] + [[package]] name = "serde_spanned" version = "1.0.4" @@ -2331,6 +2355,18 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "toml" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +dependencies = [ + "serde", + "serde_spanned 0.6.9", + "toml_datetime 0.6.11", + "toml_edit", +] + [[package]] name = "toml" version = "1.0.6+spec-1.1.0" @@ -2339,13 +2375,22 @@ checksum = 
"399b1124a3c9e16766831c6bba21e50192572cdd98706ea114f9502509686ffc" dependencies = [ "indexmap", "serde_core", - "serde_spanned", - "toml_datetime", + "serde_spanned 1.0.4", + "toml_datetime 1.0.0+spec-1.1.0", "toml_parser", "toml_writer", "winnow", ] +[[package]] +name = "toml_datetime" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" +dependencies = [ + "serde", +] + [[package]] name = "toml_datetime" version = "1.0.0+spec-1.1.0" @@ -2355,6 +2400,20 @@ dependencies = [ "serde_core", ] +[[package]] +name = "toml_edit" +version = "0.22.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +dependencies = [ + "indexmap", + "serde", + "serde_spanned 0.6.9", + "toml_datetime 0.6.11", + "toml_write", + "winnow", +] + [[package]] name = "toml_parser" version = "1.0.9+spec-1.1.0" @@ -2364,6 +2423,12 @@ dependencies = [ "winnow", ] +[[package]] +name = "toml_write" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" + [[package]] name = "toml_writer" version = "1.0.6+spec-1.1.0" @@ -2438,6 +2503,7 @@ name = "usbscsid" version = "0.1.0" dependencies = [ "base64 0.11.0", + "bitflags 2.11.0", "daemon", "driver-block", "libredox", @@ -2445,6 +2511,7 @@ dependencies = [ "redox_event", "redox_syscall 0.7.4", "thiserror 2.0.18", + "toml 1.0.6+spec-1.1.0", "xhcid", ] @@ -2801,6 +2868,9 @@ name = "winnow" version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr 2.8.0", +] [[package]] name = "wit-bindgen" @@ -2913,7 +2983,7 @@ dependencies = [ "serde_json", "smallvec 1.15.1", "thiserror 2.0.18", - "toml", + "toml 1.0.6+spec-1.1.0", ] 
[[package]] diff --git a/bootstrap/Cargo.lock b/bootstrap/Cargo.lock index e738c973..50057616 100644 --- a/bootstrap/Cargo.lock +++ b/bootstrap/Cargo.lock @@ -41,6 +41,7 @@ checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" [[package]] name = "generic-rt" version = "0.1.0" +source = "git+https://gitlab.redox-os.org/redox-os/relibc.git#c35c291beabda0818fd3369d5de5d38553a1759e" [[package]] name = "goblin" @@ -150,6 +151,7 @@ checksum = "436d45c2b6a5b159d43da708e62b25be3a4a3d5550d654b72216ade4c4bfd717" [[package]] name = "redox-rt" version = "0.1.0" +source = "git+https://gitlab.redox-os.org/redox-os/relibc.git#c35c291beabda0818fd3369d5de5d38553a1759e" dependencies = [ "bitflags", "generic-rt", diff --git a/drivers/acpid/Cargo.toml b/drivers/acpid/Cargo.toml index 2d22a8f9..fea105c8 100644 --- a/drivers/acpid/Cargo.toml +++ b/drivers/acpid/Cargo.toml @@ -21,6 +21,7 @@ rustc-hash = "1.1.0" thiserror.workspace = true ron.workspace = true serde.workspace = true +toml.workspace = true amlserde = { path = "../amlserde" } common = { path = "../common" } diff --git a/drivers/acpid/src/acpi.rs b/drivers/acpid/src/acpi.rs index 94a1eb17..58bcc22d 100644 --- a/drivers/acpid/src/acpi.rs +++ b/drivers/acpid/src/acpi.rs @@ -8,6 +8,7 @@ use std::str::FromStr; use std::sync::{Arc, Mutex}; use std::{fmt, mem}; use syscall::PAGE_SIZE; +use toml::Value; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] use common::io::{Io, Pio}; @@ -25,6 +26,8 @@ use amlserde::{AmlSerde, AmlSerdeValue}; #[cfg(target_arch = "x86_64")] pub mod dmar; +#[cfg(target_arch = "x86_64")] +use self::dmar::Dmar; use crate::aml_physmem::{AmlPageCache, AmlPhysMemHandler}; /// The raw SDT header struct, as defined by the ACPI specification. 
@@ -206,6 +209,562 @@ impl Sdt {
     }
 }
 
+/// Identification strings read from the SMBIOS (DMI) structure table.
+///
+/// A field is `None` when the firmware did not provide the string or when it
+/// was empty after trimming.
+#[derive(Clone, Debug, Default)]
+pub struct DmiInfo {
+    pub sys_vendor: Option<String>,
+    pub board_vendor: Option<String>,
+    pub board_name: Option<String>,
+    pub board_version: Option<String>,
+    pub product_name: Option<String>,
+    pub product_version: Option<String>,
+    pub bios_version: Option<String>,
+}
+
+impl DmiInfo {
+    /// Renders every present field as a `key=value` line, joined with `\n`.
+    ///
+    /// Absent fields are skipped, so an empty result means no DMI data at all.
+    pub fn to_match_lines(&self) -> String {
+        let mut lines = Vec::new();
+        if let Some(value) = &self.sys_vendor {
+            lines.push(format!("sys_vendor={value}"));
+        }
+        if let Some(value) = &self.board_vendor {
+            lines.push(format!("board_vendor={value}"));
+        }
+        if let Some(value) = &self.board_name {
+            lines.push(format!("board_name={value}"));
+        }
+        if let Some(value) = &self.board_version {
+            lines.push(format!("board_version={value}"));
+        }
+        if let Some(value) = &self.product_name {
+            lines.push(format!("product_name={value}"));
+        }
+        if let Some(value) = &self.product_version {
+            lines.push(format!("product_version={value}"));
+        }
+        if let Some(value) = &self.bios_version {
+            lines.push(format!("bios_version={value}"));
+        }
+        lines.join("\n")
+    }
+}
+
+/// SMBIOS 2.x (32-bit) entry point structure, located by its `_SM_` anchor.
+#[repr(C, packed)]
+struct Smbios2EntryPoint {
+    anchor: [u8; 4],
+    checksum: u8,
+    length: u8,
+    major: u8,
+    minor: u8,
+    max_structure_size: u16,
+    entry_point_revision: u8,
+    formatted_area: [u8; 5],
+    intermediate_anchor: [u8; 5],
+    intermediate_checksum: u8,
+    table_length: u16,
+    table_address: u32,
+    structure_count: u16,
+    bcd_revision: u8,
+}
+unsafe impl plain::Plain for Smbios2EntryPoint {}
+
+/// SMBIOS 3.x (64-bit) entry point structure, located by its `_SM3_` anchor.
+#[repr(C, packed)]
+struct Smbios3EntryPoint {
+    anchor: [u8; 5],
+    checksum: u8,
+    length: u8,
+    major: u8,
+    minor: u8,
+    docrev: u8,
+    entry_point_revision: u8,
+    reserved: u8,
+    table_max_size: u32,
+    table_address: u64,
+}
+unsafe impl plain::Plain for Smbios3EntryPoint {}
+
+/// Header that starts every structure in the SMBIOS structure table.
+#[repr(C, packed)]
+#[derive(Clone, Copy)]
+struct SmbiosStructHeader {
+    kind: u8,
+    length: u8,
+    handle: u16,
+}
+unsafe impl plain::Plain for SmbiosStructHeader {}
+
+/// Returns true when the bytes sum to zero modulo 256.
+fn checksum_ok(bytes: &[u8]) -> bool {
+    bytes
+        .iter()
+        .copied()
+        .fold(0u8, |acc, byte| acc.wrapping_add(byte))
+        == 0
+}
+
+/// Scans the legacy BIOS area (0xF0000..0x100000) on 16-byte boundaries for a
+/// checksummed SMBIOS 2.x entry point; returns `(table_address, table_length)`.
+fn scan_smbios2() -> Option<(usize, usize)> {
+    const START: usize = 0xF0000;
+    const END: usize = 0x100000;
+    let mapped = PhysmapGuard::map(START, (END - START).div_ceil(PAGE_SIZE)).ok()?;
+    let bytes = &mapped[..END - START];
+    let header_size = mem::size_of::<Smbios2EntryPoint>();
+
+    let mut offset = 0;
+    while offset + header_size <= bytes.len() {
+        if &bytes[offset..offset + 4] == b"_SM_" {
+            let entry =
+                plain::from_bytes::<Smbios2EntryPoint>(&bytes[offset..offset + header_size])
+                    .ok()?;
+            let length = entry.length as usize;
+            if offset + length <= bytes.len()
+                && length >= header_size
+                && checksum_ok(&bytes[offset..offset + length])
+                && &entry.intermediate_anchor == b"_DMI_"
+            {
+                return Some((entry.table_address as usize, entry.table_length as usize));
+            }
+        }
+        offset += 16;
+    }
+    None
+}
+
+/// Scans the legacy BIOS area (0xF0000..0x100000) on 16-byte boundaries for a
+/// checksummed SMBIOS 3.x entry point; returns `(table_address, table_max_size)`.
+fn scan_smbios3() -> Option<(usize, usize)> {
+    const START: usize = 0xF0000;
+    const END: usize = 0x100000;
+    let mapped = PhysmapGuard::map(START, (END - START).div_ceil(PAGE_SIZE)).ok()?;
+    let bytes = &mapped[..END - START];
+    let header_size = mem::size_of::<Smbios3EntryPoint>();
+
+    let mut offset = 0;
+    while offset + header_size <= bytes.len() {
+        if &bytes[offset..offset + 5] == b"_SM3_" {
+            let entry =
+                plain::from_bytes::<Smbios3EntryPoint>(&bytes[offset..offset + header_size])
+                    .ok()?;
+            let length = entry.length as usize;
+            if offset + length <= bytes.len()
+                && length >= header_size
+                && checksum_ok(&bytes[offset..offset + length])
+            {
+                return Some((entry.table_address as usize, entry.table_max_size as usize));
+            }
+        }
+        offset += 16;
+    }
+    None
+}
+
+/// Returns string number `index` (1-based) from an SMBIOS string-set.
+///
+/// Index 0 means "no string"; an index past the end of the set, or a string that
+/// is empty after trimming, also yields `None`. Invalid UTF-8 is replaced lossily.
+fn smbios_string(strings: &[u8], index: u8) -> Option<String> {
+    if index == 0 {
+        return None;
+    }
+    let mut current = 1u8;
+    for part in strings.split(|b| *b == 0) {
+        if part.is_empty() {
+            break;
+        }
+        if current == index {
+            return Some(String::from_utf8_lossy(part).trim().to_string())
+                .filter(|s| !s.is_empty());
+        }
+        current = current.saturating_add(1);
+    }
+    None
+}
+
+/// Walks the SMBIOS structure table at physical `table_addr` and collects the
+/// BIOS (type 0), system (type 1) and baseboard (type 2) strings.
+///
+/// Returns `None` when the table is empty, cannot be mapped, or yields no strings.
+fn parse_smbios_table(table_addr: usize, table_len: usize) -> Option<DmiInfo> {
+    if table_len == 0 {
+        return None;
+    }
+    let mapped = PhysmapGuard::map(
+        table_addr / PAGE_SIZE * PAGE_SIZE,
+        (table_addr % PAGE_SIZE + table_len).div_ceil(PAGE_SIZE),
+    )
+    .ok()?;
+    let start = table_addr % PAGE_SIZE;
+    let bytes = &mapped[start..start + table_len];
+    let header_size = mem::size_of::<SmbiosStructHeader>();
+    let mut offset = 0usize;
+    let mut info = DmiInfo::default();
+
+    while offset + header_size <= bytes.len() {
+        let header =
+            plain::from_bytes::<SmbiosStructHeader>(&bytes[offset..offset + header_size]).ok()?;
+        let formatted_len = header.length as usize;
+        if formatted_len < header_size || offset + formatted_len > bytes.len() {
+            break;
+        }
+        let struct_bytes = &bytes[offset..offset + formatted_len];
+
+        // The string-set follows the formatted area and ends with a double NUL.
+        // Firmware tables may be truncated before that terminator, so clamp to the
+        // remaining bytes instead of slicing out of bounds.
+        let strings_start = offset + formatted_len;
+        let (strings, next_offset) = match bytes[strings_start..]
+            .windows(2)
+            .position(|pair| pair == [0u8, 0u8])
+        {
+            Some(pos) => (
+                &bytes[strings_start..strings_start + pos + 1],
+                strings_start + pos + 2,
+            ),
+            None => (&bytes[strings_start..], bytes.len()),
+        };
+
+        match header.kind {
+            0 if formatted_len >= 0x09 => {
+                info.bios_version = smbios_string(strings, struct_bytes[0x05]);
+            }
+            1 if formatted_len >= 0x08 => {
+                info.sys_vendor = smbios_string(strings, struct_bytes[0x04]);
+                info.product_name = smbios_string(strings, struct_bytes[0x05]);
+                info.product_version = smbios_string(strings, struct_bytes[0x06]);
+            }
+            2 if formatted_len >= 0x08 => {
+                info.board_vendor = smbios_string(strings, struct_bytes[0x04]);
+                info.board_name = smbios_string(strings, struct_bytes[0x05]);
+                info.board_version = smbios_string(strings, struct_bytes[0x06]);
+            }
+            // Type 127 is the end-of-table marker.
+            127 => break,
+            _ => {}
+        }
+
+        offset = next_offset;
+    }
+
+    if info.to_match_lines().is_empty() {
+        None
+    } else {
+        Some(info)
+    }
+}
+
+/// Locates the SMBIOS structure table (preferring the 3.x entry point) and
+/// parses the DMI identification strings from it.
+pub fn load_dmi_info() -> Option<DmiInfo> {
+    let (addr, len) = scan_smbios3().or_else(scan_smbios2)?;
+    parse_smbios_table(addr, len)
+}
+
+/// DMI fields a quirk rule is restricted to; `None` fields are unconstrained.
+#[derive(Clone, Debug, Default)]
+struct AcpiTableMatchRule {
+    sys_vendor: Option<String>,
+    board_vendor: Option<String>,
+    board_name: Option<String>,
+    board_version: Option<String>,
+    product_name: Option<String>,
+    product_version: Option<String>,
+    bios_version: Option<String>,
+}
+
+impl AcpiTableMatchRule {
+    /// True when no field is constrained, i.e. the rule applies to every machine.
+    fn is_empty(&self) -> bool {
+        self.sys_vendor.is_none()
+            && self.board_vendor.is_none()
+            && self.board_name.is_none()
+            && self.board_version.is_none()
+            && self.product_name.is_none()
+            && self.product_version.is_none()
+            && self.bios_version.is_none()
+    }
+
+    /// True when every constrained field equals the corresponding DMI string exactly.
+    fn matches(&self, info: &DmiInfo) -> bool {
+        fn field_matches(expected: &Option<String>, actual: &Option<String>) -> bool {
+            match expected {
+                Some(expected) => actual.as_ref() == Some(expected),
+                None => true,
+            }
+        }
+
+        field_matches(&self.sys_vendor, &info.sys_vendor)
+            && field_matches(&self.board_vendor, &info.board_vendor)
+            && field_matches(&self.board_name, &info.board_name)
+            && field_matches(&self.board_version, &info.board_version)
+            && field_matches(&self.product_name, &info.product_name)
+            && field_matches(&self.product_version, &info.product_version)
+            && field_matches(&self.bios_version, &info.bios_version)
+    }
+}
+
+/// One `[[acpi_table_quirk]]` entry: skip the ACPI table with `signature` on
+/// machines whose DMI data satisfies `dmi_match`.
+#[derive(Clone, Debug)]
+struct AcpiTableQuirkRule {
+    signature: [u8; 4],
+    dmi_match: AcpiTableMatchRule,
+}
+
+impl AcpiTableQuirkRule {
+    /// Returns true when a table with `signature` must be skipped on this machine.
+    ///
+    /// A rule without match fields applies everywhere, even when no DMI data could
+    /// be read; a rule with match fields never applies without DMI data.
+    fn applies_to(&self, signature: [u8; 4], dmi_info: Option<&DmiInfo>) -> bool {
+        signature == self.signature
+            && (self.dmi_match.is_empty()
+                || dmi_info.is_some_and(|info| self.dmi_match.matches(info)))
+    }
+}
+
+/// Directory scanned for `*.toml` quirk files.
+const ACPI_QUIRKS_DIR: &str = "/etc/quirks.d";
+
+/// Parses a four-byte table signature such as `"DMAR"`.
+fn parse_acpi_signature(value: &str) -> Option<[u8; 4]> {
+    let bytes = value.as_bytes();
+    if bytes.len() != 4 {
+        return None;
+    }
+    Some([bytes[0], bytes[1], bytes[2], bytes[3]])
+}
+
+/// Builds a match rule from the `match` table of a quirk entry.
+///
+/// Returns `None` (after logging) on an unknown key or a non-string value:
+/// ignoring either would silently widen the rule to machines it was not meant for.
+fn parse_match_rule(table: &toml::Table, path: &str) -> Option<AcpiTableMatchRule> {
+    let mut rule = AcpiTableMatchRule::default();
+    for (key, value) in table {
+        let slot = match key.as_str() {
+            "sys_vendor" => &mut rule.sys_vendor,
+            "board_vendor" => &mut rule.board_vendor,
+            "board_name" => &mut rule.board_name,
+            "board_version" => &mut rule.board_version,
+            "product_name" => &mut rule.product_name,
+            "product_version" => &mut rule.product_version,
+            "bios_version" => &mut rule.bios_version,
+            _ => {
+                log::warn!("acpid: {path}: unknown acpi_table_quirk match key {key:?}");
+                return None;
+            }
+        };
+        let Some(value) = value.as_str() else {
+            log::warn!("acpid: {path}: acpi_table_quirk match key {key:?} is not a string");
+            return None;
+        };
+        *slot = Some(value.to_string());
+    }
+    Some(rule)
+}
+
+/// Extracts every valid `[[acpi_table_quirk]]` entry from a parsed quirk file.
+///
+/// Invalid entries are logged (prefixed with `path`) and skipped; a document
+/// without an `acpi_table_quirk` array yields no rules.
+fn parse_acpi_table_quirks(document: &Value, path: &str) -> Vec<AcpiTableQuirkRule> {
+    let Some(entries) = document.get("acpi_table_quirk").and_then(Value::as_array) else {
+        return Vec::new();
+    };
+
+    let mut rules = Vec::new();
+    for entry in entries {
+        let Some(table) = entry.as_table() else {
+            log::warn!("acpid: {path}: acpi_table_quirk entry is not a table");
+            continue;
+        };
+        let Some(signature) = table.get("signature").and_then(Value::as_str) else {
+            log::warn!("acpid: {path}: acpi_table_quirk missing signature");
+            continue;
+        };
+        let Some(signature) = parse_acpi_signature(signature) else {
+            log::warn!("acpid: {path}: invalid acpi table signature {signature:?}");
+            continue;
+        };
+
+        // A missing `match` table means "every machine"; a malformed one drops the
+        // rule instead of silently becoming unconditional.
+        let dmi_match = match table.get("match").map(Value::as_table) {
+            None => AcpiTableMatchRule::default(),
+            Some(Some(match_table)) => match parse_match_rule(match_table, path) {
+                Some(rule) => rule,
+                None => continue,
+            },
+            Some(None) => {
+                log::warn!("acpid: {path}: acpi_table_quirk match is not a table");
+                continue;
+            }
+        };
+
+        rules.push(AcpiTableQuirkRule {
+            signature,
+            dmi_match,
+        });
+    }
+
+    rules
+}
+
+/// Loads quirk rules from every `*.toml` file in `ACPI_QUIRKS_DIR`, in sorted path
+/// order. A missing directory yields no rules; unreadable or unparsable files are
+/// logged and skipped.
+fn load_acpi_table_quirks() -> Vec<AcpiTableQuirkRule> {
+    let Ok(entries) = std::fs::read_dir(ACPI_QUIRKS_DIR) else {
+        return Vec::new();
+    };
+
+    let mut paths = entries
+        .filter_map(Result::ok)
+        .map(|entry| entry.path())
+        .filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("toml"))
+        .collect::<Vec<_>>();
+    paths.sort();
+
+    let mut rules = Vec::new();
+    for path in paths {
+        let path_str = path.display().to_string();
+        let contents = match std::fs::read_to_string(&path) {
+            Ok(contents) => contents,
+            Err(err) => {
+                log::warn!("acpid: failed to read {path_str}: {err}");
+                continue;
+            }
+        };
+        // Parse as a whole document. `str::parse::<Value>()` parses a single TOML
+        // value (not a document) in current `toml` releases, which rejects files
+        // containing `[[acpi_table_quirk]]` arrays.
+        let document = match toml::from_str::<Value>(&contents) {
+            Ok(document) => document,
+            Err(err) => {
+                log::warn!("acpid: failed to parse {path_str}: {err}");
+                continue;
+            }
+        };
+        rules.extend(parse_acpi_table_quirks(&document, &path_str));
+    }
+    rules
+}
+
+/// Drops every table selected by a quirk rule from `ACPI_QUIRKS_DIR`.
+///
+/// `dmi_info` is `None` when no SMBIOS data was found; only rules without match
+/// fields can apply then. Each dropped table is logged.
+fn apply_acpi_table_quirks(mut tables: Vec<Sdt>, dmi_info: Option<&DmiInfo>) -> Vec<Sdt> {
+    let rules = load_acpi_table_quirks();
+    if rules.is_empty() {
+        return tables;
+    }
+
+    tables.retain(|table| {
+        let signature = table.signature;
+        let skip = rules.iter().any(|rule| rule.applies_to(signature, dmi_info));
+        if skip {
+            log::warn!(
+                "acpid: skipping ACPI table {} due to acpi_table_quirk rule",
+                String::from_utf8_lossy(&signature)
+            );
+        }
+        !skip
+    });
+    tables
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{
+        parse_acpi_table_quirks, smbios_string, AcpiTableMatchRule, AcpiTableQuirkRule, DmiInfo,
+    };
+
+    #[test]
+    fn dmi_info_formats_key_value_lines() {
+        let info = DmiInfo {
+            sys_vendor: Some("Framework".to_string()),
+            board_name: Some("FRANMECP01".to_string()),
+            product_name: Some("Laptop 16".to_string()),
+            ..DmiInfo::default()
+        };
+
+        let rendered = info.to_match_lines();
+        assert_eq!(
+            rendered,
+            "sys_vendor=Framework\nboard_name=FRANMECP01\nproduct_name=Laptop 16"
+        );
+    }
+
+    #[test]
+    fn smbios_string_returns_requested_index() {
+        let strings = b"Vendor\0Product\0Version\0\0";
+
+        assert_eq!(smbios_string(strings, 1).as_deref(), Some("Vendor"));
+        assert_eq!(smbios_string(strings, 2).as_deref(), Some("Product"));
+        assert_eq!(smbios_string(strings, 3).as_deref(), Some("Version"));
+        assert_eq!(smbios_string(strings, 4), None);
+    }
+
+    #[test]
+    fn parses_acpi_table_quirk_documents() {
+        let document = toml::from_str::<toml::Value>(
+            r#"
+            [[acpi_table_quirk]]
+            signature = "DMAR"
+            match = { sys_vendor = "Framework" }
+
+            [[acpi_table_quirk]]
+            signature = "TOOLONG"
+
+            [[acpi_table_quirk]]
+            signature = "SSDT"
+            match = { sys_vendr = "Typo" }
+            "#,
+        )
+        .expect("quirk document should parse");
+
+        let rules = parse_acpi_table_quirks(&document, "test.toml");
+        assert_eq!(rules.len(), 1);
+        assert_eq!(&rules[0].signature, b"DMAR");
+        assert_eq!(rules[0].dmi_match.sys_vendor.as_deref(), Some("Framework"));
+    }
+
+    #[test]
+    fn quirk_rules_respect_dmi_match() {
+        let framework = DmiInfo {
+            sys_vendor: Some("Framework".to_string()),
+            ..DmiInfo::default()
+        };
+        let conditional = AcpiTableQuirkRule {
+            signature: *b"DMAR",
+            dmi_match: AcpiTableMatchRule {
+                sys_vendor: Some("Framework".to_string()),
+                ..AcpiTableMatchRule::default()
+            },
+        };
+        let unconditional = AcpiTableQuirkRule {
+            signature: *b"SSDT",
+            dmi_match: AcpiTableMatchRule::default(),
+        };
+
+        assert!(conditional.applies_to(*b"DMAR", Some(&framework)));
+        assert!(!conditional.applies_to(*b"DMAR", Some(&DmiInfo::default())));
+        assert!(!conditional.applies_to(*b"DMAR", None));
+        assert!(!conditional.applies_to(*b"FACP", Some(&framework)));
+        assert!(unconditional.applies_to(*b"SSDT", None));
+    }
+}
+
 impl Deref for Sdt {
     type Target = SdtHeader;
 
@@ -244,16 +803,16 @@ pub struct AmlSymbols {
     // k = name, v = description
     symbol_cache: FxHashMap<String, String>,
     page_cache: Arc<Mutex<AmlPageCache>>,
-    aml_region_handlers: Vec<(RegionSpace, Box<dyn RegionHandler>)>,
 }
 
 impl AmlSymbols {
-    pub fn new(aml_region_handlers: Vec<(RegionSpace, Box<dyn RegionHandler>)>) -> Self {
+    // NOTE(review): AML region handlers are no longer injected here; confirm the
+    // EmbeddedControl handler that main.rs used to construct is registered elsewhere.
+    pub fn new() -> Self {
         Self {
             aml_context: None,
             symbol_cache: FxHashMap::default(),
             page_cache: Arc::new(Mutex::new(AmlPageCache::default())),
-            aml_region_handlers,
         }
     }
 
@@ -261,6 +820,9 @@ impl AmlSymbols {
         if self.aml_context.is_some() {
             return Err("AML interpreter already initialized".into());
         }
+        if pci_fd.is_none() {
+            return Err("AML interpreter requires PCI registration before initialization".into());
+        }
         let format_err = |err| format!("{:?}", err);
         let handler = AmlPhysMemHandler::new(pci_fd, Arc::clone(&self.page_cache));
         //TODO: use these parsed tables for the rest of acpid
diff --git a/drivers/acpid/src/acpi/dmar/mod.rs b/drivers/acpid/src/acpi/dmar/mod.rs
index c42b379a..f4dff276 100644
--- a/drivers/acpid/src/acpi/dmar/mod.rs
+++ b/drivers/acpid/src/acpi/dmar/mod.rs
@@ -474,10 +474,18 @@ impl<'sdt> Iterator for DmarRawIter<'sdt> {
         let len_bytes = <[u8; 2]>::try_from(type_bytes)
             .expect("expected a 2-byte slice to be convertible to [u8; 2]");
 
-        let ty = u16::from_ne_bytes(type_bytes);
-        let len = u16::from_ne_bytes(len_bytes);
+        // NOTE(review): `len_bytes` above is built from `type_bytes`, not from the
+        // length field. Confirm that is intended; otherwise `len` equals the entry
+        // type and every entry with a type below 4 is rejected by the check below.
+        let len = u16::from_ne_bytes(len_bytes) as usize;
 
-        let len = usize::try_from(len).expect("expected u16 to fit within usize");
+        // Each remapping structure includes its own 2-byte type and 2-byte length,
+        // so a shorter length marks a malformed table.
+        if len < 4 {
+            return None;
+        }
+
+        let ty = u16::from_ne_bytes(type_bytes);
 
         if len > remainder.len() {
             log::warn!("DMAR remapping structure length was smaller than the remaining length of the table.");
diff --git a/drivers/acpid/src/main.rs b/drivers/acpid/src/main.rs
index 0933f638..916e1864 100644
--- 
a/drivers/acpid/src/main.rs +++ b/drivers/acpid/src/main.rs @@ -4,7 +4,6 @@ use std::mem; use std::os::unix::io::AsRawFd; use std::sync::Arc; -use ::acpi::aml::op_region::{RegionHandler, RegionSpace}; use event::{EventFlags, RawEventQueue}; use redox_scheme::{ scheme::{register_sync_scheme, SchemeState, SchemeSync}, @@ -69,11 +68,7 @@ fn daemon(daemon: daemon::Daemon) -> ! { _ => panic!("acpid: expected [RX]SDT from kernel to be either of those"), }; - let region_handlers: Vec<(RegionSpace, Box)> = vec![ - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - (RegionSpace::EmbeddedControl, Box::new(ec::Ec::new())), - ]; - let acpi_context = self::acpi::AcpiContext::init(physaddrs_iter, region_handlers); + let acpi_context = self::acpi::AcpiContext::init(physaddrs_iter); // TODO: I/O permission bitmap? #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] diff --git a/drivers/acpid/src/scheme.rs b/drivers/acpid/src/scheme.rs index 5a5040c3..5f1232bd 100644 --- a/drivers/acpid/src/scheme.rs +++ b/drivers/acpid/src/scheme.rs @@ -2,7 +2,6 @@ use acpi::aml::namespace::AmlName; use amlserde::aml_serde_name::to_aml_format; use amlserde::AmlSerdeValue; use core::str; -use libredox::Fd; use parking_lot::RwLockReadGuard; use redox_scheme::scheme::SchemeSync; use redox_scheme::{CallerCtx, OpenResult, SendFdRequest, Socket}; @@ -16,17 +15,19 @@ use syscall::FobtainFdFlags; use syscall::data::Stat; use syscall::error::{Error, Result}; -use syscall::error::{EACCES, EBADF, EBADFD, EINVAL, EIO, EISDIR, ENOENT, ENOTDIR}; +use syscall::error::{EACCES, EAGAIN, EBADF, EBADFD, EINVAL, EIO, EISDIR, ENOENT, ENOTDIR}; use syscall::flag::{MODE_DIR, MODE_FILE}; use syscall::flag::{O_ACCMODE, O_DIRECTORY, O_RDONLY, O_STAT, O_SYMLINK}; use syscall::{EOVERFLOW, EPERM}; -use crate::acpi::{AcpiContext, AmlSymbols, SdtSignature}; +use crate::acpi::{ + AcpiBattery, AcpiContext, AcpiPowerAdapter, AcpiPowerSnapshot, AmlSymbols, DmiInfo, + SdtSignature, +}; pub struct AcpiScheme<'acpi, 'sock> 
{ ctx: &'acpi AcpiContext, handles: HandleMap>, - pci_fd: Option, socket: &'sock Socket, } @@ -41,10 +42,156 @@ enum HandleKind<'a> { Table(SdtSignature), Symbols(RwLockReadGuard<'a, AmlSymbols>), Symbol { name: String, description: String }, + DmiDir, + Dmi(String), + PowerDir, + PowerAdaptersDir, + PowerAdapterDir(String), + PowerBatteriesDir, + PowerBatteryDir(String), + PowerFile(String), SchemeRoot, RegisterPci, } +const DMI_DIRECTORY_ENTRIES: &[&str] = &[ + "sys_vendor", + "board_vendor", + "board_name", + "board_version", + "product_name", + "product_version", + "bios_version", + "match_all", +]; + +fn dmi_contents(dmi_info: Option<&DmiInfo>, name: &str) -> Option { + Some(match name { + "sys_vendor" => dmi_info + .and_then(|info| info.sys_vendor.clone()) + .unwrap_or_default(), + "board_vendor" => dmi_info + .and_then(|info| info.board_vendor.clone()) + .unwrap_or_default(), + "board_name" => dmi_info + .and_then(|info| info.board_name.clone()) + .unwrap_or_default(), + "board_version" => dmi_info + .and_then(|info| info.board_version.clone()) + .unwrap_or_default(), + "product_name" => dmi_info + .and_then(|info| info.product_name.clone()) + .unwrap_or_default(), + "product_version" => dmi_info + .and_then(|info| info.product_version.clone()) + .unwrap_or_default(), + "bios_version" => dmi_info + .and_then(|info| info.bios_version.clone()) + .unwrap_or_default(), + "match_all" => dmi_info.map(DmiInfo::to_match_lines).unwrap_or_default(), + _ => return None, + }) +} + +fn power_bool_contents(value: bool) -> String { + if value { + String::from("1\n") + } else { + String::from("0\n") + } +} + +fn power_u64_contents(value: u64) -> String { + format!("{value}\n") +} + +fn power_f64_contents(value: f64) -> String { + format!("{value}\n") +} + +fn power_string_contents(value: &str) -> String { + format!("{value}\n") +} + +fn power_adapter_file_contents(adapter: &AcpiPowerAdapter, name: &str) -> Option { + Some(match name { + "path" => 
power_string_contents(&adapter.path), + "online" => power_bool_contents(adapter.online), + _ => return None, + }) +} + +fn power_adapter_entry_names() -> &'static [&'static str] { + &["path", "online"] +} + +fn power_battery_file_contents(battery: &AcpiBattery, name: &str) -> Option { + Some(match name { + "path" => power_string_contents(&battery.path), + "state" => power_u64_contents(battery.state), + "present_rate" => power_u64_contents(battery.present_rate?), + "remaining_capacity" => power_u64_contents(battery.remaining_capacity?), + "present_voltage" => power_u64_contents(battery.present_voltage?), + "power_unit" => power_string_contents(battery.power_unit.as_deref()?), + "design_capacity" => power_u64_contents(battery.design_capacity?), + "last_full_capacity" => power_u64_contents(battery.last_full_capacity?), + "design_voltage" => power_u64_contents(battery.design_voltage?), + "technology" => power_string_contents(battery.technology.as_deref()?), + "model" => power_string_contents(battery.model.as_deref()?), + "serial" => power_string_contents(battery.serial.as_deref()?), + "battery_type" => power_string_contents(battery.battery_type.as_deref()?), + "oem_info" => power_string_contents(battery.oem_info.as_deref()?), + "percentage" => power_f64_contents(battery.percentage?), + _ => return None, + }) +} + +fn power_battery_entry_names(battery: &AcpiBattery) -> Vec<&'static str> { + let mut names = vec!["path", "state"]; + + if battery.present_rate.is_some() { + names.push("present_rate"); + } + if battery.remaining_capacity.is_some() { + names.push("remaining_capacity"); + } + if battery.present_voltage.is_some() { + names.push("present_voltage"); + } + if battery.power_unit.is_some() { + names.push("power_unit"); + } + if battery.design_capacity.is_some() { + names.push("design_capacity"); + } + if battery.last_full_capacity.is_some() { + names.push("last_full_capacity"); + } + if battery.design_voltage.is_some() { + names.push("design_voltage"); + } + if 
battery.technology.is_some() { + names.push("technology"); + } + if battery.model.is_some() { + names.push("model"); + } + if battery.serial.is_some() { + names.push("serial"); + } + if battery.battery_type.is_some() { + names.push("battery_type"); + } + if battery.oem_info.is_some() { + names.push("oem_info"); + } + if battery.percentage.is_some() { + names.push("percentage"); + } + + names +} + impl HandleKind<'_> { fn is_dir(&self) -> bool { match self { @@ -53,6 +200,14 @@ impl HandleKind<'_> { Self::Table(_) => false, Self::Symbols(_) => true, Self::Symbol { .. } => false, + Self::DmiDir => true, + Self::Dmi(_) => false, + Self::PowerDir => true, + Self::PowerAdaptersDir => true, + Self::PowerAdapterDir(_) => true, + Self::PowerBatteriesDir => true, + Self::PowerBatteryDir(_) => true, + Self::PowerFile(_) => false, Self::SchemeRoot => false, Self::RegisterPci => false, } @@ -65,8 +220,18 @@ impl HandleKind<'_> { .ok_or(Error::new(EBADFD))? .length(), Self::Symbol { description, .. } => description.len(), + Self::Dmi(contents) => contents.len(), + Self::PowerFile(contents) => contents.len(), // Directories - Self::TopLevel | Self::Symbols(_) | Self::Tables => 0, + Self::TopLevel + | Self::Symbols(_) + | Self::Tables + | Self::DmiDir + | Self::PowerDir + | Self::PowerAdaptersDir + | Self::PowerAdapterDir(_) + | Self::PowerBatteriesDir + | Self::PowerBatteryDir(_) => 0, Self::SchemeRoot | Self::RegisterPci => return Err(Error::new(EBADF)), }) } @@ -77,10 +242,99 @@ impl<'acpi, 'sock> AcpiScheme<'acpi, 'sock> { Self { ctx, handles: HandleMap::new(), - pci_fd: None, socket, } } + + fn power_snapshot(&self) -> Result { + self.ctx.power_snapshot().map_err(|error| { + log::warn!("Failed to build ACPI power snapshot: {:?}", error); + Error::new(EIO) + }) + } + + fn power_handle(&self, path: &str) -> Result> { + let normalized = path.trim_matches('/'); + + if normalized.is_empty() { + return Ok(HandleKind::PowerDir); + } + if normalized == "on_battery" { + return 
Ok(HandleKind::PowerFile(power_bool_contents( + self.power_snapshot()?.on_battery(), + ))); + } + if normalized == "adapters" { + return Ok(HandleKind::PowerAdaptersDir); + } + if let Some(rest) = normalized.strip_prefix("adapters/") { + return self.power_adapter_handle(rest); + } + if normalized == "batteries" { + return Ok(HandleKind::PowerBatteriesDir); + } + if let Some(rest) = normalized.strip_prefix("batteries/") { + return self.power_battery_handle(rest); + } + + Err(Error::new(ENOENT)) + } + + fn power_adapter_handle(&self, path: &str) -> Result> { + let normalized = path.trim_matches('/'); + if normalized.is_empty() { + return Ok(HandleKind::PowerAdaptersDir); + } + + let mut parts = normalized.split('/'); + let adapter_id = parts.next().ok_or(Error::new(ENOENT))?; + let field = parts.next(); + if parts.next().is_some() { + return Err(Error::new(ENOENT)); + } + + let snapshot = self.power_snapshot()?; + let adapter = snapshot + .adapters + .iter() + .find(|adapter| adapter.id == adapter_id) + .ok_or(Error::new(ENOENT))?; + + match field { + None | Some("") => Ok(HandleKind::PowerAdapterDir(adapter.id.clone())), + Some(name) => Ok(HandleKind::PowerFile( + power_adapter_file_contents(adapter, name).ok_or(Error::new(ENOENT))?, + )), + } + } + + fn power_battery_handle(&self, path: &str) -> Result> { + let normalized = path.trim_matches('/'); + if normalized.is_empty() { + return Ok(HandleKind::PowerBatteriesDir); + } + + let mut parts = normalized.split('/'); + let battery_id = parts.next().ok_or(Error::new(ENOENT))?; + let field = parts.next(); + if parts.next().is_some() { + return Err(Error::new(ENOENT)); + } + + let snapshot = self.power_snapshot()?; + let battery = snapshot + .batteries + .iter() + .find(|battery| battery.id == battery_id) + .ok_or(Error::new(ENOENT))?; + + match field { + None | Some("") => Ok(HandleKind::PowerBatteryDir(battery.id.clone())), + Some(name) => Ok(HandleKind::PowerFile( + power_battery_file_contents(battery, 
name).ok_or(Error::new(ENOENT))?, + )), + } + } } fn parse_hex_digit(hex: u8) -> Option { @@ -184,9 +438,9 @@ impl SchemeSync for AcpiScheme<'_, '_> { HandleKind::SchemeRoot => { // TODO: arrayvec let components = { - let mut v = arrayvec::ArrayVec::<&str, 3>::new(); + let mut v = arrayvec::ArrayVec::<&str, 4>::new(); let it = path.split('/'); - for component in it.take(3) { + for component in it.take(4) { v.push(component); } @@ -195,6 +449,24 @@ impl SchemeSync for AcpiScheme<'_, '_> { match &*components { [""] => HandleKind::TopLevel, + ["dmi"] => { + if flag_dir || flag_stat || path.ends_with('/') { + HandleKind::DmiDir + } else { + HandleKind::Dmi( + dmi_contents(self.ctx.dmi_info(), "match_all") + .expect("match_all should always resolve"), + ) + } + } + ["dmi", ""] => HandleKind::DmiDir, + ["dmi", field] => HandleKind::Dmi( + dmi_contents(self.ctx.dmi_info(), field).ok_or(Error::new(ENOENT))?, + ), + ["power"] => self.power_handle("")?, + ["power", tail] => self.power_handle(tail)?, + ["power", a, b] => self.power_handle(&format!("{a}/{b}"))?, + ["power", a, b, c] => self.power_handle(&format!("{a}/{b}/{c}"))?, ["register_pci"] => HandleKind::RegisterPci, ["tables"] => HandleKind::Tables, @@ -204,7 +476,11 @@ impl SchemeSync for AcpiScheme<'_, '_> { } ["symbols"] => { - if let Ok(aml_symbols) = self.ctx.aml_symbols(self.pci_fd.as_ref()) { + if !self.ctx.pci_ready() { + log::warn!("Deferring AML symbol scan until PCI registration is ready"); + return Err(Error::new(EAGAIN)); + } + if let Ok(aml_symbols) = self.ctx.aml_symbols() { HandleKind::Symbols(aml_symbols) } else { return Err(Error::new(EIO)); @@ -212,6 +488,12 @@ impl SchemeSync for AcpiScheme<'_, '_> { } ["symbols", symbol] => { + if !self.ctx.pci_ready() { + log::warn!( + "Deferring AML symbol lookup for {symbol} until PCI registration is ready" + ); + return Err(Error::new(EAGAIN)); + } if let Some(description) = self.ctx.aml_lookup(symbol) { HandleKind::Symbol { name: (*symbol).to_owned(), @@ 
-225,6 +507,15 @@ impl SchemeSync for AcpiScheme<'_, '_> { _ => return Err(Error::new(ENOENT)), } } + HandleKind::DmiDir => { + if path.is_empty() { + HandleKind::DmiDir + } else { + HandleKind::Dmi( + dmi_contents(self.ctx.dmi_info(), path).ok_or(Error::new(ENOENT))?, + ) + } + } HandleKind::Symbols(ref aml_symbols) => { if let Some(description) = aml_symbols.lookup(path) { HandleKind::Symbol { @@ -235,6 +526,23 @@ impl SchemeSync for AcpiScheme<'_, '_> { return Err(Error::new(ENOENT)); } } + HandleKind::PowerDir => self.power_handle(path)?, + HandleKind::PowerAdaptersDir => self.power_adapter_handle(path)?, + HandleKind::PowerAdapterDir(ref adapter_id) => { + if path.is_empty() { + HandleKind::PowerAdapterDir(adapter_id.clone()) + } else { + self.power_adapter_handle(&format!("{adapter_id}/{path}"))? + } + } + HandleKind::PowerBatteriesDir => self.power_battery_handle(path)?, + HandleKind::PowerBatteryDir(ref battery_id) => { + if path.is_empty() { + HandleKind::PowerBatteryDir(battery_id.clone()) + } else { + self.power_battery_handle(&format!("{battery_id}/{path}"))? + } + } _ => return Err(Error::new(EACCES)), }; @@ -296,7 +604,7 @@ impl SchemeSync for AcpiScheme<'_, '_> { ) -> Result { let offset: usize = offset.try_into().map_err(|_| Error::new(EINVAL))?; - let handle = self.handles.get_mut(id)?; + let handle = self.handles.get(id)?; if handle.stat { return Err(Error::new(EBADF)); @@ -309,6 +617,8 @@ impl SchemeSync for AcpiScheme<'_, '_> { .ok_or(Error::new(EBADFD))? .as_slice(), HandleKind::Symbol { description, .. 
} => description.as_bytes(), + HandleKind::Dmi(contents) => contents.as_bytes(), + HandleKind::PowerFile(contents) => contents.as_bytes(), _ => return Err(Error::new(EINVAL)), }; @@ -328,11 +638,11 @@ impl SchemeSync for AcpiScheme<'_, '_> { mut buf: DirentBuf<&'buf mut [u8]>, opaque_offset: u64, ) -> Result> { - let handle = self.handles.get_mut(id)?; + let handle = self.handles.get(id)?; match &handle.kind { HandleKind::TopLevel => { - const TOPLEVEL_ENTRIES: &[&str] = &["tables", "symbols"]; + const TOPLEVEL_ENTRIES: &[&str] = &["tables", "symbols", "dmi", "power"]; for (idx, name) in TOPLEVEL_ENTRIES .iter() @@ -347,6 +657,111 @@ impl SchemeSync for AcpiScheme<'_, '_> { })?; } } + HandleKind::DmiDir => { + for (idx, name) in DMI_DIRECTORY_ENTRIES + .iter() + .enumerate() + .skip(opaque_offset as usize) + { + buf.entry(DirEntry { + inode: 0, + next_opaque_id: idx as u64 + 1, + name, + kind: DirentKind::Regular, + })?; + } + } + HandleKind::PowerDir => { + const POWER_ROOT_ENTRIES: &[(&str, DirentKind)] = &[ + ("on_battery", DirentKind::Regular), + ("adapters", DirentKind::Directory), + ("batteries", DirentKind::Directory), + ]; + + for (idx, (name, kind)) in POWER_ROOT_ENTRIES + .iter() + .enumerate() + .skip(opaque_offset as usize) + { + buf.entry(DirEntry { + inode: 0, + next_opaque_id: idx as u64 + 1, + name, + kind: *kind, + })?; + } + } + HandleKind::PowerAdaptersDir => { + let snapshot = self.power_snapshot()?; + for (idx, adapter) in snapshot + .adapters + .iter() + .enumerate() + .skip(opaque_offset as usize) + { + buf.entry(DirEntry { + inode: 0, + next_opaque_id: idx as u64 + 1, + name: adapter.id.as_str(), + kind: DirentKind::Directory, + })?; + } + } + HandleKind::PowerAdapterDir(adapter_id) => { + let snapshot = self.power_snapshot()?; + let _adapter = snapshot + .adapters + .iter() + .find(|adapter| adapter.id == *adapter_id) + .ok_or(Error::new(EIO))?; + + for (idx, name) in power_adapter_entry_names() + .iter() + .enumerate() + 
.skip(opaque_offset as usize) + { + buf.entry(DirEntry { + inode: 0, + next_opaque_id: idx as u64 + 1, + name, + kind: DirentKind::Regular, + })?; + } + } + HandleKind::PowerBatteriesDir => { + let snapshot = self.power_snapshot()?; + for (idx, battery) in snapshot + .batteries + .iter() + .enumerate() + .skip(opaque_offset as usize) + { + buf.entry(DirEntry { + inode: 0, + next_opaque_id: idx as u64 + 1, + name: battery.id.as_str(), + kind: DirentKind::Directory, + })?; + } + } + HandleKind::PowerBatteryDir(battery_id) => { + let snapshot = self.power_snapshot()?; + let battery = snapshot + .batteries + .iter() + .find(|battery| battery.id == *battery_id) + .ok_or(Error::new(EIO))?; + let entry_names = power_battery_entry_names(battery); + + for (idx, name) in entry_names.iter().enumerate().skip(opaque_offset as usize) { + buf.entry(DirEntry { + inode: 0, + next_opaque_id: idx as u64 + 1, + name, + kind: DirentKind::Regular, + })?; + } + } HandleKind::Symbols(aml_symbols) => { for (idx, (symbol_name, _value)) in aml_symbols .symbols_cache() @@ -470,10 +885,8 @@ impl SchemeSync for AcpiScheme<'_, '_> { } let new_fd = libredox::Fd::new(new_fd); - if self.pci_fd.is_some() { + if self.ctx.register_pci_fd(new_fd).is_err() { return Err(Error::new(EINVAL)); - } else { - self.pci_fd = Some(new_fd); } Ok(num_fds) @@ -483,3 +896,38 @@ impl SchemeSync for AcpiScheme<'_, '_> { self.handles.remove(id); } } + +#[cfg(test)] +mod tests { + use super::dmi_contents; + use crate::acpi::DmiInfo; + + #[test] + fn dmi_contents_exposes_individual_fields_and_match_all() { + let dmi_info = DmiInfo { + sys_vendor: Some("Framework".to_string()), + board_name: Some("FRANMECP01".to_string()), + product_name: Some("Laptop 16".to_string()), + ..DmiInfo::default() + }; + + assert_eq!( + dmi_contents(Some(&dmi_info), "sys_vendor").as_deref(), + Some("Framework") + ); + assert_eq!( + dmi_contents(Some(&dmi_info), "board_name").as_deref(), + Some("FRANMECP01") + ); + assert_eq!( + 
dmi_contents(Some(&dmi_info), "product_name").as_deref(), + Some("Laptop 16") + ); + assert_eq!( + dmi_contents(Some(&dmi_info), "match_all").as_deref(), + Some("sys_vendor=Framework\nboard_name=FRANMECP01\nproduct_name=Laptop 16") + ); + assert_eq!(dmi_contents(None, "bios_version").as_deref(), Some("")); + assert_eq!(dmi_contents(Some(&dmi_info), "unknown"), None); + } +} diff --git a/drivers/acpid/src/main.rs b/drivers/acpid/src/main.rs index 916e1864..52d8c8b4 100644 --- a/drivers/acpid/src/main.rs +++ b/drivers/acpid/src/main.rs @@ -16,6 +16,7 @@ mod aml_physmem; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] mod ec; +mod sleep; mod scheme; #[derive(Debug, Error)] diff --git a/drivers/acpid/src/sleep.rs b/drivers/acpid/src/sleep.rs new file mode 100644 index 00000000..f8095663 --- /dev/null +++ b/drivers/acpid/src/sleep.rs @@ -0,0 +1,84 @@ +use std::convert::TryFrom; + +#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] +pub enum SleepTarget { + S1, + S3, + S4, + S5, +} + +impl SleepTarget { + pub fn aml_method_name(self) -> &'static str { + match self { + Self::S1 => "_S1", + Self::S3 => "_S3", + Self::S4 => "_S4", + Self::S5 => "_S5", + } + } + + pub fn is_soft_off(self) -> bool { + matches!(self, Self::S5) + } +} + +impl TryFrom for SleepTarget { + type Error = (); + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::S1), + 3 => Ok(Self::S3), + 4 => Ok(Self::S4), + 5 => Ok(Self::S5), + _ => Err(()), + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum SleepPhase { + Prepare, + Enter, + Resume, +} + +#[cfg(test)] +mod tests { + use std::convert::TryFrom; + + use super::{SleepPhase, SleepTarget}; + + #[test] + fn sleep_target_maps_to_expected_aml_names() { + assert_eq!(SleepTarget::S1.aml_method_name(), "_S1"); + assert_eq!(SleepTarget::S3.aml_method_name(), "_S3"); + assert_eq!(SleepTarget::S4.aml_method_name(), "_S4"); + assert_eq!(SleepTarget::S5.aml_method_name(), "_S5"); + } + + #[test] + fn 
sleep_target_parsing_accepts_expected_states() { + assert_eq!(SleepTarget::try_from(1), Ok(SleepTarget::S1)); + assert_eq!(SleepTarget::try_from(3), Ok(SleepTarget::S3)); + assert_eq!(SleepTarget::try_from(4), Ok(SleepTarget::S4)); + assert_eq!(SleepTarget::try_from(5), Ok(SleepTarget::S5)); + assert_eq!(SleepTarget::try_from(2), Err(())); + } + + #[test] + fn only_s5_is_currently_treated_as_soft_off() { + assert!(!SleepTarget::S1.is_soft_off()); + assert!(!SleepTarget::S3.is_soft_off()); + assert!(!SleepTarget::S4.is_soft_off()); + assert!(SleepTarget::S5.is_soft_off()); + } + + #[test] + fn sleep_phase_debug_surface_is_stable() { + assert_eq!(format!("{:?}", SleepPhase::Prepare), "Prepare"); + assert_eq!(format!("{:?}", SleepPhase::Enter), "Enter"); + assert_eq!(format!("{:?}", SleepPhase::Resume), "Resume"); + } +} diff --git a/drivers/acpid/src/acpi.rs b/drivers/acpid/src/acpi.rs index 58bcc22d..4f817811 100644 --- a/drivers/acpid/src/acpi.rs +++ b/drivers/acpid/src/acpi.rs @@ -15,6 +15,7 @@ use common::io::{Io, Pio}; use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use thiserror::Error; +use crate::sleep::{SleepPhase, SleepTarget}; use acpi::{ aml::{namespace::AmlName, AmlError, Interpreter}, @@ -952,7 +953,7 @@ pub struct AcpiContext { fadt: Option, pm1a_cnt_blk: u64, pm1b_cnt_blk: u64, - s5_values: RwLock>, + sleep_values: RwLock>, reset_reg: Option, reset_value: u8, @@ -1240,7 +1241,7 @@ impl AcpiContext { fadt, pm1a_cnt_blk, pm1b_cnt_blk, - s5_values: RwLock::new(None), + sleep_values: RwLock::new(std::collections::BTreeMap::new()), reset_reg, reset_value, pci_fd: RwLock::new(None), @@ -1385,23 +1386,35 @@ impl AcpiContext { /// - search for PM1a /// See https://forum.osdev.org/viewtopic.php?t=16990 for practical details pub fn set_global_s_state(&self, state: u8) { - if state != 5 { - return; - } - let (slp_typa, slp_typb) = if let Some(values) = *self.s5_values.read() { - values - } else { - let Ok(values) = self.evaluate_acpi_method("\\", 
"_S5", &[]) else { - log::error!("Cannot set S-state, failed to evaluate \\_S5"); - return; - }; - if values.len() < 2 { - log::error!("Cannot set S-state, \\_S5 package too small"); - return; - } - let values = (values[0] as u8, values[1] as u8); - *self.s5_values.write() = Some(values); - values - }; + let target = match SleepTarget::try_from(state) { + Ok(target) => target, + Err(_) => { + log::error!("Cannot set S-state {state}, unsupported target"); + return; + } + }; + + if !target.is_soft_off() { + log::warn!( + "ACPI sleep groundwork only: {} is recognized but not implemented yet", + target.aml_method_name() + ); + return; + } + + log::info!("acpid: {:?} {}", SleepPhase::Prepare, target.aml_method_name()); + + let (slp_typa, slp_typb) = match self.sleep_type_values(target) { + Some(values) => values, + None => return, + }; let mut val = 1 << 13; log::trace!("Shutdown SLP_TYPa {:X}, SLP_TYPb {:X}", slp_typa, slp_typb); @@ -1412,6 +1425,7 @@ impl AcpiContext { { if self.pm1a_cnt_blk != 0 { let port = self.pm1a_cnt_blk as u16; + log::info!("acpid: {:?} {} via PM1a", SleepPhase::Enter, target.aml_method_name()); log::warn!("Shutdown with ACPI outw(0x{:X}, 0x{:X})", port, val); Pio::::new(port).write(val); } @@ -1419,6 +1433,7 @@ impl AcpiContext { if self.pm1b_cnt_blk != 0 { let mut val_b = 1 << 13; val_b |= u16::from(slp_typb); + log::info!("acpid: {:?} {} via PM1b", SleepPhase::Enter, target.aml_method_name()); let port = self.pm1b_cnt_blk as u16; log::warn!("Shutdown with ACPI outw(0x{:X}, 0x{:X})", port, val_b); Pio::::new(port).write(val_b); @@ -1438,6 +1453,23 @@ impl AcpiContext { } } + fn sleep_type_values(&self, target: SleepTarget) -> Option<(u8, u8)> { + if let Some(values) = self.sleep_values.read().get(&target).copied() { + return Some(values); + } + + let method_name = target.aml_method_name(); + let Ok(values) = self.evaluate_acpi_method("\\", method_name, &[]) else { + log::error!("Cannot set S-state, failed to evaluate \\{method_name}"); + 
return None; + }; + if values.len() < 2 { + log::error!("Cannot set S-state, \\{method_name} package too small"); + return None; + } + let values = (values[0] as u8, values[1] as u8); + self.sleep_values.write().insert(target, values); + Some(values) + } + pub fn acpi_reboot(&self) { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] if let Some(reset_reg) = &self.reset_reg { diff --git a/drivers/graphics/ihdgd/config.toml b/drivers/graphics/ihdgd/config.toml diff --git a/drivers/storage/usbscsid/src/quirks.rs b/drivers/storage/usbscsid/src/quirks.rs index 5051f1b0..ae4c2e46 100644 --- a/drivers/storage/usbscsid/src/quirks.rs +++ b/drivers/storage/usbscsid/src/quirks.rs @@ -128,7 +128,7 @@ fn quirk_files() -> Option> { } fn parse_runtime_quirks_from_toml(text: &str) -> Vec { - let Ok(value) = text.parse::() else { + let Ok(value) = text.trim().parse::() else { return Vec::new(); }; @@ -201,13 +201,7 @@ mod tests { #[test] fn runtime_toml_parser_keeps_supported_flags_and_skips_unknown_ones() { let entries = parse_runtime_quirks_from_toml( - r#" - [[usb_storage_quirk]] - vendor = 4660 - product = 22136 - flags = ["ignore_residue", "unknown_flag", "fix_capacity"] - "#, + "[[usb_storage_quirk]]\nvendor = 4660\nproduct = 22136\nflags = [\"ignore_residue\", \"unknown_flag\", \"fix_capacity\"]\n", ); assert_eq!(entries.len(), 1); diff --git a/drivers/storage/usbscsid/src/scsi/cmds.rs b/drivers/storage/usbscsid/src/scsi/cmds.rs index ddc12336..df5d0a0c 100644 --- a/drivers/storage/usbscsid/src/scsi/cmds.rs +++ b/drivers/storage/usbscsid/src/scsi/cmds.rs @@ -265,6 +265,57 @@ impl Write10 { } } +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct SynchronizeCache10 { + pub opcode: u8, + pub a: u8, + pub lba: u32, + pub group_num: u8, + pub blocks: u16, + pub control: u8, +} +unsafe impl plain::Plain for SynchronizeCache10 {} + +impl SynchronizeCache10 { + pub const fn new(lba: u64, blocks: u16, control: u8) -> Self { + Self { + opcode: Opcode::SyncCache10 as 
u8, + a: 0, + lba: u32::to_be(lba as u32), + group_num: 0, + blocks: u16::to_be(blocks), + control, + } + } +} + +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct SynchronizeCache16 { + pub opcode: u8, + pub a: u8, + pub lba: u64, + pub blocks: u32, + pub group_num: u8, + pub control: u8, +} +unsafe impl plain::Plain for SynchronizeCache16 {} + +impl SynchronizeCache16 { + pub const fn new(lba: u64, blocks: u32, control: u8) -> Self { + Self { + opcode: Opcode::SyncCache16 as u8, + a: 0, + lba: u64::to_be(lba), + blocks: u32::to_be(blocks), + group_num: 0, + control, + } + } +} + #[repr(C, packed)] #[derive(Clone, Copy, Debug)] pub struct ModeSense6 { diff --git a/drivers/storage/usbscsid/src/scsi/mod.rs b/drivers/storage/usbscsid/src/scsi/mod.rs index b6d379d0..cf4a9707 100644 --- a/drivers/storage/usbscsid/src/scsi/mod.rs +++ b/drivers/storage/usbscsid/src/scsi/mod.rs @@ -25,6 +25,8 @@ const REQUEST_SENSE_CMD_LEN: u8 = 6; const MIN_INQUIRY_ALLOC_LEN: u16 = 5; const MIN_REPORT_SUPP_OPCODES_ALLOC_LEN: u32 = 4; const MAX_SECTORS_64_LIMIT: u64 = 64; +const SYNC_CACHE10_CMD_LEN: usize = 10; +const SYNC_CACHE16_CMD_LEN: usize = 16; @@ -286,6 +288,12 @@ impl Scsi { pub fn cmd_write10(&mut self) -> Result<&mut cmds::Write10> { parse_mut_bytes("WRITE(10) command", &mut self.command_buffer) } + pub fn cmd_sync_cache10(&mut self) -> Result<&mut cmds::SynchronizeCache10> { + parse_mut_bytes("SYNCHRONIZE CACHE(10) command", &mut self.command_buffer) + } + pub fn cmd_sync_cache16(&mut self) -> Result<&mut cmds::SynchronizeCache16> { + parse_mut_bytes("SYNCHRONIZE CACHE(16) command", &mut self.command_buffer) + } pub fn res_standard_inquiry_data(&self) -> Result<&StandardInquiryData> { parse_bytes("standard inquiry data", &self.inquiry_buffer) } @@ -467,6 +475,10 @@ impl Scsi { let status = protocol.send_command( &self.command_buffer[..10], DeviceReqData::Out(&buffer[..bytes_to_write]), )?; + if self.quirks.contains(UsbStorageQuirkFlags::NEEDS_SYNC_CACHE) + && 
status.kind == SendCommandStatusKind::Success + { + self.sync_cache(protocol, lba, blocks_to_write)?; + } Ok(status.bytes_transferred(bytes_to_write as u32)) } else { @@ -482,8 +494,83 @@ impl Scsi { let status = protocol.send_command( &self.command_buffer[..16], DeviceReqData::Out(&buffer[..bytes_to_write]), )?; + if self.quirks.contains(UsbStorageQuirkFlags::NEEDS_SYNC_CACHE) + && status.kind == SendCommandStatusKind::Success + { + self.sync_cache(protocol, lba, blocks_to_write)?; + } Ok(status.bytes_transferred(bytes_to_write as u32)) } } + + fn sync_cache(&mut self, protocol: &mut dyn Protocol, lba: u64, blocks: u64) -> Result<()> { + let use_sync_cache10 = self.quirks.contains(UsbStorageQuirkFlags::INITIAL_READ10) + && u32::try_from(lba).is_ok() + && u16::try_from(blocks).is_ok(); + + let status = if use_sync_cache10 { + let sync = self.cmd_sync_cache10()?; + *sync = cmds::SynchronizeCache10::new( + lba, + u16::try_from(blocks) + .map_err(|_| ScsiError::Overflow("sync cache(10) block count overflow"))?, + 0, + ); + protocol.send_command(&self.command_buffer[..SYNC_CACHE10_CMD_LEN], DeviceReqData::NoData)? + } else { + let sync = self.cmd_sync_cache16()?; + *sync = cmds::SynchronizeCache16::new( + lba, + u32::try_from(blocks) + .map_err(|_| ScsiError::Overflow("sync cache(16) block count overflow"))?, + 0, + ); + protocol.send_command(&self.command_buffer[..SYNC_CACHE16_CMD_LEN], DeviceReqData::NoData)? 
+ }; + + if status.kind == SendCommandStatusKind::Success { + return Ok(()); + } + + if let Ok(()) = self.get_ff_sense(protocol, cmds::RequestSense::MINIMAL_ALLOC_LEN) { + if let Ok(sense) = self.res_ff_sense_data() { + if sense.add_sense_code == 0x3A + || sense.add_sense_code == 0x20 + || (sense.add_sense_code == 0x04 && sense.add_sense_code_qual == 0x04) + || sense.sense_key() == cmds::SenseKey::IllegalRequest + { + return Ok(()); + } + } + } + + Err(ScsiError::ProtocolError(ProtocolError::ProtocolError( + "SYNCHRONIZE CACHE command failed", + ))) + } } @@ -527,3 +614,53 @@ impl<'a> BlkDescSlice<'a> { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::SendCommandStatus; + use crate::scsi::opcodes::Opcode; + + struct MockProtocol { + commands: Vec>, + } + + impl MockProtocol { + fn new() -> Self { + Self { commands: Vec::new() } + } + } + + impl Protocol for MockProtocol { + fn send_command(&mut self, command: &[u8], _data: DeviceReqData) -> std::result::Result { + self.commands.push(command.to_vec()); + Ok(SendCommandStatus { residue: None, kind: SendCommandStatusKind::Success }) + } + } + + fn scsi_for_tests(quirks: UsbStorageQuirkFlags) -> Scsi { + Scsi { + command_buffer: [0u8; 16], + inquiry_buffer: [0u8; 259], + data_buffer: Vec::new(), + block_size: 512, + block_count: 1024, + quirks, + } + } + + #[test] + fn sync_cache_uses_10_byte_command_for_initial_read10_quirk() { + let mut scsi = scsi_for_tests(UsbStorageQuirkFlags::INITIAL_READ10); + let mut protocol = MockProtocol::new(); + scsi.sync_cache(&mut protocol, 7, 4).unwrap(); + assert_eq!(protocol.commands.len(), 1); + assert_eq!(protocol.commands[0].len(), 10); + assert_eq!(protocol.commands[0][0], Opcode::SyncCache10 as u8); + } + + #[test] + fn sync_cache_uses_16_byte_command_without_initial_read10_quirk() { + let mut scsi = scsi_for_tests(UsbStorageQuirkFlags::empty()); + let mut protocol = MockProtocol::new(); + scsi.sync_cache(&mut protocol, 7, 4).unwrap(); + 
assert_eq!(protocol.commands.len(), 1); + assert_eq!(protocol.commands[0].len(), 16); + assert_eq!(protocol.commands[0][0], Opcode::SyncCache16 as u8); + } +} diff --git a/drivers/storage/usbscsid/src/quirks.rs b/drivers/storage/usbscsid/src/quirks.rs index ae4c2e46..30380aea 100644 --- a/drivers/storage/usbscsid/src/quirks.rs +++ b/drivers/storage/usbscsid/src/quirks.rs @@ -128,7 +128,7 @@ fn quirk_files() -> Option> { } fn parse_runtime_quirks_from_toml(text: &str) -> Vec { - let Ok(value) = text.parse::() else { + let Ok(value) = text.trim().parse::
() else { return Vec::new(); }; @@ -201,13 +201,7 @@ mod tests { #[test] fn runtime_toml_parser_keeps_supported_flags_and_skips_unknown_ones() { let entries = parse_runtime_quirks_from_toml( - r#" - [[usb_storage_quirk]] - vendor = 4660 - product = 22136 - flags = ["ignore_residue", "unknown_flag", "fix_capacity"] - "#, + "[[usb_storage_quirk]]\nvendor = 4660\nproduct = 22136\nflags = [\"ignore_residue\", \"unknown_flag\", \"fix_capacity\"]\n", ); assert_eq!(entries.len(), 1); diff --git a/drivers/storage/usbscsid/src/scsi/cmds.rs b/drivers/storage/usbscsid/src/scsi/cmds.rs index df5d0a0c..0674c6b5 100644 --- a/drivers/storage/usbscsid/src/scsi/cmds.rs +++ b/drivers/storage/usbscsid/src/scsi/cmds.rs @@ -265,6 +265,57 @@ impl Write10 { } } +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct SynchronizeCache10 { + pub opcode: u8, + pub a: u8, + pub lba: u32, + pub group_num: u8, + pub blocks: u16, + pub control: u8, +} +unsafe impl plain::Plain for SynchronizeCache10 {} + +impl SynchronizeCache10 { + pub const fn new(lba: u64, blocks: u16, control: u8) -> Self { + Self { + opcode: Opcode::SyncCache10 as u8, + a: 0, + lba: u32::to_be(lba as u32), + group_num: 0, + blocks: u16::to_be(blocks), + control, + } + } +} + +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct SynchronizeCache16 { + pub opcode: u8, + pub a: u8, + pub lba: u64, + pub blocks: u32, + pub group_num: u8, + pub control: u8, +} +unsafe impl plain::Plain for SynchronizeCache16 {} + +impl SynchronizeCache16 { + pub const fn new(lba: u64, blocks: u32, control: u8) -> Self { + Self { + opcode: Opcode::SyncCache16 as u8, + a: 0, + lba: u64::to_be(lba), + blocks: u32::to_be(blocks), + group_num: 0, + control, + } + } +} + #[repr(C, packed)] #[derive(Clone, Copy, Debug)] pub struct ModeSense6 { diff --git a/drivers/storage/usbscsid/src/scsi/mod.rs b/drivers/storage/usbscsid/src/scsi/mod.rs index cf4a9707..ad0565ca 100644 --- a/drivers/storage/usbscsid/src/scsi/mod.rs +++ 
b/drivers/storage/usbscsid/src/scsi/mod.rs @@ -25,6 +25,8 @@ const REQUEST_SENSE_CMD_LEN: u8 = 6; const MIN_INQUIRY_ALLOC_LEN: u16 = 5; const MIN_REPORT_SUPP_OPCODES_ALLOC_LEN: u32 = 4; const MAX_SECTORS_64_LIMIT: u64 = 64; +const SYNC_CACHE10_CMD_LEN: usize = 10; +const SYNC_CACHE16_CMD_LEN: usize = 16; @@ -286,6 +288,12 @@ impl Scsi { pub fn cmd_write10(&mut self) -> Result<&mut cmds::Write10> { parse_mut_bytes("WRITE(10) command", &mut self.command_buffer) } + pub fn cmd_sync_cache10(&mut self) -> Result<&mut cmds::SynchronizeCache10> { + parse_mut_bytes("SYNCHRONIZE CACHE(10) command", &mut self.command_buffer) + } + pub fn cmd_sync_cache16(&mut self) -> Result<&mut cmds::SynchronizeCache16> { + parse_mut_bytes("SYNCHRONIZE CACHE(16) command", &mut self.command_buffer) + } pub fn res_standard_inquiry_data(&self) -> Result<&StandardInquiryData> { parse_bytes("standard inquiry data", &self.inquiry_buffer) } @@ -467,6 +475,10 @@ impl Scsi { let status = protocol.send_command( &self.command_buffer[..10], DeviceReqData::Out(&buffer[..bytes_to_write]), )?; + if self.quirks.contains(UsbStorageQuirkFlags::NEEDS_SYNC_CACHE) + && status.kind == SendCommandStatusKind::Success + { + self.sync_cache(protocol, lba, blocks_to_write)?; + } Ok(status.bytes_transferred(bytes_to_write as u32)) } else { @@ -482,8 +494,83 @@ impl Scsi { let status = protocol.send_command( &self.command_buffer[..16], DeviceReqData::Out(&buffer[..bytes_to_write]), )?; + if self.quirks.contains(UsbStorageQuirkFlags::NEEDS_SYNC_CACHE) + && status.kind == SendCommandStatusKind::Success + { + self.sync_cache(protocol, lba, blocks_to_write)?; + } Ok(status.bytes_transferred(bytes_to_write as u32)) } } + + fn sync_cache(&mut self, protocol: &mut dyn Protocol, lba: u64, blocks: u64) -> Result<()> { + let use_sync_cache10 = self.quirks.contains(UsbStorageQuirkFlags::INITIAL_READ10) + && u32::try_from(lba).is_ok() + && u16::try_from(blocks).is_ok(); + + let status = if use_sync_cache10 { + let sync = 
self.cmd_sync_cache10()?; + *sync = cmds::SynchronizeCache10::new( + lba, + u16::try_from(blocks) + .map_err(|_| ScsiError::Overflow("sync cache(10) block count overflow"))?, + 0, + ); + protocol.send_command(&self.command_buffer[..SYNC_CACHE10_CMD_LEN], DeviceReqData::NoData)? + } else { + let sync = self.cmd_sync_cache16()?; + *sync = cmds::SynchronizeCache16::new( + lba, + u32::try_from(blocks) + .map_err(|_| ScsiError::Overflow("sync cache(16) block count overflow"))?, + 0, + ); + protocol.send_command(&self.command_buffer[..SYNC_CACHE16_CMD_LEN], DeviceReqData::NoData)? + }; + + if status.kind == SendCommandStatusKind::Success { + return Ok(()); + } + + if let Ok(()) = self.get_ff_sense(protocol, cmds::RequestSense::MINIMAL_ALLOC_LEN) { + if let Ok(sense) = self.res_ff_sense_data() { + if sense.add_sense_code == 0x3A + || sense.add_sense_code == 0x20 + || (sense.add_sense_code == 0x04 && sense.add_sense_code_qual == 0x04) + || sense.sense_key() == cmds::SenseKey::IllegalRequest + { + return Ok(()); + } + } + } + + Err(ScsiError::ProtocolError(ProtocolError::ProtocolError( + "SYNCHRONIZE CACHE command failed", + ))) + } } @@ -527,3 +614,53 @@ impl<'a> BlkDescSlice<'a> { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::SendCommandStatus; + use crate::scsi::opcodes::Opcode; + + struct MockProtocol { + commands: Vec>, + } + + impl MockProtocol { + fn new() -> Self { + Self { commands: Vec::new() } + } + } + + impl Protocol for MockProtocol { + fn send_command(&mut self, command: &[u8], _data: DeviceReqData) -> std::result::Result { + self.commands.push(command.to_vec()); + Ok(SendCommandStatus { residue: None, kind: SendCommandStatusKind::Success }) + } + } + + fn scsi_for_tests(quirks: UsbStorageQuirkFlags) -> Scsi { + Scsi { + command_buffer: [0u8; 16], + inquiry_buffer: [0u8; 259], + data_buffer: Vec::new(), + block_size: 512, + block_count: 1024, + quirks, + } + } + + #[test] + fn 
sync_cache_uses_10_byte_command_for_initial_read10_quirk() { + let mut scsi = scsi_for_tests(UsbStorageQuirkFlags::INITIAL_READ10); + let mut protocol = MockProtocol::new(); + scsi.sync_cache(&mut protocol, 7, 4).unwrap(); + assert_eq!(protocol.commands.len(), 1); + assert_eq!(protocol.commands[0].len(), 10); + assert_eq!(protocol.commands[0][0], Opcode::SyncCache10 as u8); + } + + #[test] + fn sync_cache_uses_16_byte_command_without_initial_read10_quirk() { + let mut scsi = scsi_for_tests(UsbStorageQuirkFlags::empty()); + let mut protocol = MockProtocol::new(); + scsi.sync_cache(&mut protocol, 7, 4).unwrap(); + assert_eq!(protocol.commands.len(), 1); + assert_eq!(protocol.commands[0].len(), 16); + assert_eq!(protocol.commands[0][0], Opcode::SyncCache16 as u8); + } +} diff --git a/drivers/usb/xhcid/src/xhci/device_enumerator.rs b/drivers/usb/xhcid/src/xhci/device_enumerator.rs index 1f144ac9..00000000 100644 --- a/drivers/usb/xhcid/src/xhci/device_enumerator.rs +++ b/drivers/usb/xhcid/src/xhci/device_enumerator.rs @@ -4,8 +4,12 @@ use crate::xhci::{PortId, Xhci}; use common::io::Io; use crossbeam_channel; use log::{debug, info, warn}; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use syscall::EAGAIN; + +const DEFAULT_PORT_RESET_SETTLE_MS: u64 = 16; +const RESET_DELAY_PORT_RESET_SETTLE_MS: u64 = 100; +const HUB_SLOW_RESET_PORT_RESET_SETTLE_MS: u64 = 200; pub struct DeviceEnumerationRequest { pub port_id: PortId, @@ -25,10 +29,14 @@ impl DeviceEnumerator { loop { debug!("Start Device Enumerator Loop"); let request = match self.request_queue.recv() { Ok(req) => req, Err(err) => { - panic!("Failed to received an enumeration request! error: {}", err) + warn!( + "device enumerator stopping after request queue closed: {}", + err + ); + break; } }; @@ -64,13 +72,12 @@ impl DeviceEnumerator { //If the port isn't enabled (i.e. 
it's a USB2 port), we need to reset it if it isn't resetting already //A USB3 port won't generate a Connect Status Change until it's already enabled, so this check //will always be skipped for USB3 ports if !flags.contains(PortFlags::PED) { - let disabled_state = flags.contains(PortFlags::PP) - && flags.contains(PortFlags::CCS) - && !flags.contains(PortFlags::PED) - && !flags.contains(PortFlags::PR); + let disabled_state = Self::port_is_disabled(&flags); if !disabled_state { - panic!( - "Port {} isn't in the disabled state! Current flags: {:?}", + warn!( + "Port {} never reached the disabled state before reset-driven enumeration; current flags: {:?}", port_id, flags ); + continue; } else { debug!("Port {} has entered the disabled state.", port_id); } @@ -89,17 +96,15 @@ impl DeviceEnumerator { port.clear_prc(); - let delay_ms = if early_quirks - .contains(crate::usb_quirks::UsbQuirkFlags::HUB_SLOW_RESET) - { - 200 - } else if early_quirks.contains(crate::usb_quirks::UsbQuirkFlags::RESET_DELAY) { - 100 - } else { - 16 - }; - - std::thread::sleep(Duration::from_millis(delay_ms)); // Some devices need extra time to settle after reset. + } + + let flags = self.wait_for_port_enabled_state( + port_array_index, + Duration::from_millis(Self::port_reset_settle_delay_ms(early_quirks)), + ); - let flags = port.flags(); - - let enabled_state = flags.contains(PortFlags::PP) - && flags.contains(PortFlags::CCS) - && flags.contains(PortFlags::PED) - && !flags.contains(PortFlags::PR); + let enabled_state = Self::port_is_enabled(&flags); if !enabled_state { warn!( - "Port {} isn't in the enabled state! 
Current flags: {:?}", + "Port {} isn't in the enabled state after bounded reset settle; current flags: {:?}", port_id, flags ); + continue; } else { @@ -140,4 +145,47 @@ impl DeviceEnumerator { } } } + + fn port_reset_settle_delay_ms(quirks: crate::usb_quirks::UsbQuirkFlags) -> u64 { + if quirks.contains(crate::usb_quirks::UsbQuirkFlags::HUB_SLOW_RESET) { + HUB_SLOW_RESET_PORT_RESET_SETTLE_MS + } else if quirks.contains(crate::usb_quirks::UsbQuirkFlags::RESET_DELAY) { + RESET_DELAY_PORT_RESET_SETTLE_MS + } else { + DEFAULT_PORT_RESET_SETTLE_MS + } + } + + fn port_is_disabled(flags: &PortFlags) -> bool { + flags.contains(PortFlags::PP) + && flags.contains(PortFlags::CCS) + && !flags.contains(PortFlags::PED) + && !flags.contains(PortFlags::PR) + } + + fn port_is_enabled(flags: &PortFlags) -> bool { + flags.contains(PortFlags::PP) + && flags.contains(PortFlags::CCS) + && flags.contains(PortFlags::PED) + && !flags.contains(PortFlags::PR) + } + + fn wait_for_port_enabled_state( + &self, + port_array_index: usize, + settle_timeout: Duration, + ) -> PortFlags { + let start = Instant::now(); + + loop { + let flags = { + let ports = self + .hci + .ports + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + ports[port_array_index].flags() + }; + + if Self::port_is_enabled(&flags) + || !flags.contains(PortFlags::PR) + || start.elapsed() >= settle_timeout + { + return flags; + } + + std::thread::sleep(Duration::from_millis(1)); + } + } diff --git a/drivers/usb/xhcid/src/xhci/device_enumerator.rs b/drivers/usb/xhcid/src/xhci/device_enumerator.rs index 00000000..00000000 100644 --- a/drivers/usb/xhcid/src/xhci/device_enumerator.rs +++ b/drivers/usb/xhcid/src/xhci/device_enumerator.rs @@ -46,7 +46,11 @@ impl DeviceEnumerator { debug!("Device Enumerator request for port {}", port_id); let (len, flags) = { - let ports = self.hci.ports.lock().unwrap(); + let ports = self + .hci + .ports + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); let len = ports.len(); 
@@ -86,7 +90,11 @@ impl DeviceEnumerator { debug!("Received a device connect on port {}, but it's not enabled. Resetting the port.", port_id); if let Err(err) = self.hci.reset_port(port_id) { warn!( "failed to reset port {} before enumeration; skipping attach: {}", port_id, err ); continue; } - let mut ports = self.hci.ports.lock().unwrap(); + let mut ports = self + .hci + .ports + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); let port = &mut ports[port_array_index]; port.clear_prc(); @@ -144,10 +152,16 @@ impl DeviceEnumerator { let result = futures::executor::block_on(self.hci.detach_device(port_id)); match result { Ok(was_connected) => { if was_connected { info!("Device on port {} was detached", port_id); + } else { + debug!( + "Ignoring duplicate or out-of-order detach event for unattached port {}", + port_id + ); } } Err(err) => { - warn!("processing of device attach request failed! Error: {}", err); + warn!("processing of device detach request failed! Error: {}", err); } } diff --git a/drivers/usb/xhcid/src/xhci/mod.rs b/drivers/usb/xhcid/src/xhci/mod.rs index c53cb59f..814fdb4f 100644 --- a/drivers/usb/xhcid/src/xhci/mod.rs +++ b/drivers/usb/xhcid/src/xhci/mod.rs @@ -307,6 +307,7 @@ struct PortState { slot: u8, protocol_speed: &'static ProtocolSpeed, cfg_idx: Option, + active_alternates: BTreeMap, input_context: Mutex>>, dev_desc: Option, endpoint_states: BTreeMap, @@ -324,29 +325,37 @@ pub(crate) enum PortPmState { impl PortState { //TODO: fetch using endpoint number instead fn get_endp_desc(&self, endp_idx: u8) -> Option<&EndpDesc> { - let cfg_idx = self.cfg_idx?; - let config_desc = self - .dev_desc - .as_ref()? 
- .config_descs - .iter() - .find(|desc| desc.configuration_value == cfg_idx)?; - let mut endp_count = 0; - for if_desc in config_desc.interface_descs.iter() { - let active_alternate = self - .active_alternates - .get(&if_desc.number) - .copied() - .unwrap_or(0); - if if_desc.alternate_setting != active_alternate { - continue; - } - for endp_desc in if_desc.endpoints.iter() { - if endp_idx == endp_count { - return Some(endp_desc); - } - endp_count += 1; - } - } - None + active_endpoint_desc( + self.dev_desc.as_ref()?, + self.cfg_idx?, + &self.active_alternates, + endp_idx, + ) } } + +fn active_configuration<'a>(dev_desc: &'a DevDesc, cfg_idx: u8) -> Option<&'a ConfDesc> { + dev_desc + .config_descs + .iter() + .find(|desc| desc.configuration_value == cfg_idx) +} + +fn active_endpoint_desc<'a>(dev_desc: &'a DevDesc, cfg_idx: u8, active_alternates: &BTreeMap, endp_idx: u8) -> Option<&'a EndpDesc> { + let config_desc = active_configuration(dev_desc, cfg_idx)?; + let mut endp_count = 0; + for if_desc in config_desc.interface_descs.iter() { + let active_alternate = active_alternates.get(&if_desc.number).copied().unwrap_or(0); + if if_desc.alternate_setting != active_alternate { + continue; + } + for endp_desc in if_desc.endpoints.iter() { + if endp_idx == endp_count { + return Some(endp_desc); + } + endp_count += 1; + } + } + None +} @@ -872,6 +881,7 @@ impl Xhci { protocol_speed, input_context: Mutex::new(input), dev_desc: None, + active_alternates: BTreeMap::new(), cfg_idx: None, endpoint_states: std::iter::once(( 0, @@ -1516,6 +1526,67 @@ struct DriversConfig { drivers: Vec, } + +#[cfg(test)] +mod tests { + use super::{active_endpoint_desc, BTreeMap, ConfDesc, DevDesc, EndpDesc, IfDesc}; + use crate::driver_interface::EndpointTy; + use smallvec::smallvec; + + fn endp(address: u8, attributes: u8) -> EndpDesc { + EndpDesc { kind: 5, address, attributes, max_packet_size: 64, interval: 1, ssc: None, sspc: None } + } + + fn if_desc(number: u8, alternate_setting: u8, 
endpoints: Vec) -> IfDesc { + IfDesc { + kind: 4, + number, + alternate_setting, + class: 3, + sub_class: 1, + protocol: 1, + interface_str: None, + endpoints: endpoints.into_iter().collect(), + hid_descs: smallvec![], + } + } + + fn sample_dev_desc() -> DevDesc { + DevDesc { + kind: 1, + usb: 0x0200, + class: 0, + sub_class: 0, + protocol: 0, + packet_size: 64, + vendor: 0x1234, + product: 0x5678, + release: 0x0100, + manufacturer_str: None, + product_str: None, + serial_str: None, + config_descs: smallvec![ConfDesc { kind: 2, configuration_value: 1, configuration: None, attributes: 0x80, max_power: 50, interface_descs: smallvec![ if_desc(0, 0, vec![endp(0x81, 0x03)]), if_desc(0, 1, vec![endp(0x82, 0x03), endp(0x02, 0x03)]), if_desc(1, 0, vec![endp(0x83, 0x02)]), ], }], + } + } + + #[test] + fn active_endpoint_desc_uses_default_alternates_initially() { + let dev_desc = sample_dev_desc(); + let active = BTreeMap::new(); + let first = active_endpoint_desc(&dev_desc, 1, &active, 0).expect("endpoint 0"); + let second = active_endpoint_desc(&dev_desc, 1, &active, 1).expect("endpoint 1"); + assert_eq!(first.address, 0x81); + assert_eq!(first.ty(), EndpointTy::Interrupt); + assert_eq!(second.address, 0x83); + assert_eq!(second.ty(), EndpointTy::Bulk); + assert!(active_endpoint_desc(&dev_desc, 1, &active, 2).is_none()); + } + + #[test] + fn active_endpoint_desc_switches_to_selected_alternate() { + let dev_desc = sample_dev_desc(); + let mut active = BTreeMap::new(); + active.insert(0, 1); + let first = active_endpoint_desc(&dev_desc, 1, &active, 0).expect("endpoint 0"); + let second = active_endpoint_desc(&dev_desc, 1, &active, 1).expect("endpoint 1"); + let third = active_endpoint_desc(&dev_desc, 1, &active, 2).expect("endpoint 2"); + assert_eq!(first.address, 0x82); + assert_eq!(second.address, 0x02); + assert_eq!(third.address, 0x83); + } +} diff --git a/drivers/usb/xhcid/src/xhci/scheme.rs b/drivers/usb/xhcid/src/xhci/scheme.rs --- 
a/drivers/usb/xhcid/src/xhci/scheme.rs +++ b/drivers/usb/xhcid/src/xhci/scheme.rs @@ -3487,6 +3487,14 @@ impl EndpointContextSnapshot { fn capture_values(a: u32, b: u32, trl: u32, trh: u32, c: u32) -> Self { Self { a, b, trl, trh, c } } + + fn restore(&self, ctx: &mut EndpointContext) { + ctx.a.write(self.a); + ctx.b.write(self.b); + ctx.trl.write(self.trl); + ctx.trh.write(self.trh); + ctx.c.write(self.c); + } } @@ -1171,7 +1171,9 @@ impl XhciScheme input_context.device.slot.c.write(snapshot.slot_c); for (endp_i, endp_snapshot) in endpoint_snapshots { - endp_snapshot.restore(&mut input_context.device.endpoints[*endp_i]); + let endpoint_ptr = core::ptr::addr_of_mut!(input_context.device.endpoints[*endp_i]); + let mut endpoint = unsafe { core::ptr::read_unaligned(endpoint_ptr) }; + endp_snapshot.restore(&mut endpoint); + unsafe { core::ptr::write_unaligned(endpoint_ptr, endpoint) }; } Ok(input_context.physical()) index b0fb9b85..bba6f232 100644 --- a/drivers/usb/xhcid/src/xhci/scheme.rs +++ b/drivers/usb/xhcid/src/xhci/scheme.rs @@ -1105,11 +1105,28 @@ impl Xhci { .find(|desc| desc.configuration_value == req.config_desc) .ok_or(Error::new(EBADFD))?; + let configuration_value = config_desc.configuration_value; + + let interface_layout = config_desc + .interface_descs + .iter() + .map(|if_desc| { + ( + if_desc.number, + if_desc.alternate_setting, + if_desc.endpoints.iter().map(|endp| *endp).collect::>(), + ) + }) + .collect::>(); - //TODO: USE ENDPOINTS FROM ALL INTERFACES + port_state.active_alternates.clear(); + for (if_num, _, _) in &interface_layout { + port_state.active_alternates.entry(*if_num).or_insert(0); + } + let mut endp_desc_count = 0; let mut new_context_entries = 1; - for if_desc in config_desc.interface_descs.iter() { - for endpoint in if_desc.endpoints.iter() { + for (if_num, alternate_setting, endpoints) in &interface_layout { + let active_alternate = port_state.active_alternates.get(if_num).copied().unwrap_or(0); + if *alternate_setting != 
active_alternate { + continue; + } + for endpoint in endpoints.iter() { endp_desc_count += 1; let entry = Self::endp_num_to_dci(endp_desc_count, endpoint); if entry > new_context_entries { @@ -1128,7 +1145,7 @@ impl Xhci { ( endp_desc_count, new_context_entries, - config_desc.configuration_value, + configuration_value, ) }; @@ -1397,6 +1414,10 @@ impl Xhci { if !skip_set_interface { self.set_interface(port, interface_num, alternate_setting) .await?; + } + + if let Some(mut port_state) = self.port_states.get_mut(&port) { + port_state.active_alternates.insert(interface_num, alternate_setting); } } } diff --git a/drivers/graphics/ihdgd/config.toml b/drivers/graphics/ihdgd/config.toml index acbb4e78..210731ae 100644 --- a/drivers/graphics/ihdgd/config.toml +++ b/drivers/graphics/ihdgd/config.toml @@ -51,5 +51,26 @@ ids = { 0x8086 = [ 0x56B3, # Pro A60 0x56C0, # GPU Flex 170 0x56C1, # GPU Flex 140 + # Alder Lake-S Desktop + 0x4680, 0x4682, 0x4688, 0x468A, 0x468B, + 0x4690, 0x4692, 0x4693, + # Alder Lake-P Mobile + 0x46A0, 0x46A1, 0x46A2, 0x46A3, 0x46A6, + 0x46A8, 0x46AA, 0x462A, 0x4626, 0x4628, + 0x46B0, 0x46B1, 0x46B2, 0x46B3, + 0x46C0, 0x46C1, 0x46C2, 0x46C3, + # Alder Lake-N Low-Power + 0x46D0, 0x46D1, 0x46D2, 0x46D3, 0x46D4, + # Raptor Lake-S Desktop + 0xA780, 0xA781, 0xA782, 0xA783, + 0xA788, 0xA789, 0xA78A, 0xA78B, + # Raptor Lake-P Mobile + 0xA720, 0xA7A0, 0xA7A8, 0xA7AA, 0xA7AB, + # Raptor Lake-U Mobile + 0xA721, 0xA7A1, 0xA7A9, 0xA7AC, 0xA7AD, + # Meteor Lake + 0x7D40, 0x7D45, 0x7D55, 0x7D60, 0x7DD5, + # Arrow Lake-H + 0x7D51, 0x7DD1, ] } command = ["ihdgd"] diff --git a/drivers/hwd/src/backend/acpi.rs b/drivers/hwd/src/backend/acpi.rs index 3da41d63..ec8828ee 100644 --- a/drivers/hwd/src/backend/acpi.rs +++ b/drivers/hwd/src/backend/acpi.rs @@ -1,5 +1,6 @@ use amlserde::{AmlSerde, AmlSerdeValue}; use std::{error::Error, fs, process::Command}; +use std::{thread, time::Duration}; use super::Backend; @@ -20,14 +21,57 @@ impl Backend for AcpiBackend { } fn 
probe(&mut self) -> Result<(), Box> { - // Read symbols from acpi scheme - let entries = fs::read_dir("/scheme/acpi/symbols")?; - // TODO: Reimplement with getdents? - let symbols_fd = libredox::Fd::open( - "/scheme/acpi/symbols", - libredox::flag::O_DIRECTORY | libredox::flag::O_RDONLY, - 0, - )?; + const SYMBOL_RETRY_COUNT: usize = 20; + const SYMBOL_RETRY_DELAY: Duration = Duration::from_millis(100); + + let (entries, symbols_fd) = { + let mut last_error = None; + + let mut ready = None; + for attempt in 1..=SYMBOL_RETRY_COUNT { + match fs::read_dir("/scheme/acpi/symbols") { + Ok(entries) => match libredox::Fd::open( + "/scheme/acpi/symbols", + libredox::flag::O_DIRECTORY | libredox::flag::O_RDONLY, + 0, + ) { + Ok(symbols_fd) => { + ready = Some((entries, symbols_fd)); + break; + } + Err(err) => { + let message = format!("open failed: {err}"); + if attempt == 1 || attempt == SYMBOL_RETRY_COUNT { + log::warn!( + "ACPI symbols not ready yet (attempt {attempt}/{SYMBOL_RETRY_COUNT}): {message}" + ); + } + last_error = Some(message); + } + }, + Err(err) => { + let message = format!("read_dir failed: {err}"); + if attempt == 1 || attempt == SYMBOL_RETRY_COUNT { + log::warn!( + "ACPI symbols not ready yet (attempt {attempt}/{SYMBOL_RETRY_COUNT}): {message}" + ); + } + last_error = Some(message); + } + } + + if attempt != SYMBOL_RETRY_COUNT { + thread::sleep(SYMBOL_RETRY_DELAY); + } + } + + ready.ok_or_else(|| { + std::io::Error::other( + last_error.unwrap_or_else(|| "timed out waiting for ACPI symbols".to_string()), + ) + })? 
+ }; + for entry_res in entries { let entry = entry_res?; if let Some(file_name) = entry.file_name().to_str() { diff --git a/drivers/input/ps2d/src/controller.rs b/drivers/input/ps2d/src/controller.rs index d7af4cba..561aa527 100644 --- a/drivers/input/ps2d/src/controller.rs +++ b/drivers/input/ps2d/src/controller.rs @@ -13,7 +13,7 @@ use common::io::Pio; #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] use common::io::Mmio; -use log::{debug, error, info, trace, warn}; +use log::{debug, error, trace, warn}; use std::fmt; @@ -271,6 +271,20 @@ impl Ps2 { } } + pub fn probe(&mut self) -> bool { + let status = self.status(); + let status_bits = status.bits(); + + if status_bits == 0x00 || status_bits == 0xFF { + debug!( + "ps/2 controller probe returned suspicious status {:02X}", + status_bits + ); + } + + self.config().is_ok() + } + pub fn init_keyboard(&mut self) -> Result<(), Error> { let mut b; diff --git a/drivers/input/ps2d/src/main.rs b/drivers/input/ps2d/src/main.rs index db17de2a..faa02e99 100644 --- a/drivers/input/ps2d/src/main.rs +++ b/drivers/input/ps2d/src/main.rs @@ -20,7 +20,7 @@ mod mouse; mod state; mod vm; -fn daemon(daemon: daemon::Daemon) -> ! { +fn run() -> ! { common::setup_logging( "input", "ps2", @@ -29,9 +29,18 @@ fn daemon(daemon: daemon::Daemon) -> ! { common::file_level(), ); - acquire_port_io_rights().expect("ps2d: failed to get I/O permission"); + if let Err(err) = acquire_port_io_rights() { + log::error!("ps2d: failed to get I/O permission: {}", err); + process::exit(1); + } - let input = ProducerHandle::new().expect("ps2d: failed to open input producer"); + let input = match ProducerHandle::new() { + Ok(input) => input, + Err(err) => { + log::error!("ps2d: failed to open input producer: {}", err); + process::exit(1); + } + }; user_data! { enum Source { @@ -44,12 +53,19 @@ fn daemon(daemon: daemon::Daemon) -> ! 
{ let event_queue: EventQueue = EventQueue::new().expect("ps2d: failed to create event queue"); - let mut key_file = OpenOptions::new() + let key_file = OpenOptions::new() .read(true) .write(true) .custom_flags(syscall::O_NONBLOCK as i32) - .open("/scheme/serio/0") - .expect("ps2d: failed to open /scheme/serio/0"); + .open("/scheme/serio/0"); + + let mut key_file = match key_file { + Ok(key_file) => key_file, + Err(err) => { + log::error!("ps2d: failed to open /scheme/serio/0: {}", err); + process::exit(1); + } + }; event_queue .subscribe( @@ -59,12 +75,19 @@ fn daemon(daemon: daemon::Daemon) -> ! { ) .unwrap(); - let mut mouse_file = OpenOptions::new() + let mouse_file = OpenOptions::new() .read(true) .write(true) .custom_flags(syscall::O_NONBLOCK as i32) - .open("/scheme/serio/1") - .expect("ps2d: failed to open /scheme/serio/1"); + .open("/scheme/serio/1"); + + let mut mouse_file = match mouse_file { + Ok(mouse_file) => mouse_file, + Err(err) => { + log::error!("ps2d: failed to open /scheme/serio/1: {}", err); + process::exit(1); + } + }; event_queue .subscribe( @@ -78,8 +101,15 @@ fn daemon(daemon: daemon::Daemon) -> ! { .read(true) .write(true) .custom_flags(syscall::O_NONBLOCK as i32) - .open(format!("/scheme/time/{}", syscall::CLOCK_MONOTONIC)) - .expect("ps2d: failed to open /scheme/time"); + .open(format!("/scheme/time/{}", syscall::CLOCK_MONOTONIC)); + + let time_file = match time_file { + Ok(time_file) => time_file, + Err(err) => { + log::error!("ps2d: failed to open /scheme/time: {}", err); + process::exit(1); + } + }; event_queue .subscribe( @@ -89,11 +119,15 @@ fn daemon(daemon: daemon::Daemon) -> ! 
{ ) .unwrap(); - libredox::call::setrens(0, 0).expect("ps2d: failed to enter null namespace"); - - daemon.ready(); + if let Err(err) = libredox::call::setrens(0, 0) { + log::error!("ps2d: failed to enter null namespace: {}", err); + process::exit(1); + } - let mut ps2d = Ps2d::new(input, time_file); + let Some(mut ps2d) = Ps2d::new(input, time_file) else { + log::warn!("ps2d: no PS/2 hardware available, exiting"); + process::exit(0); + }; let mut data = [0; 256]; for event in event_queue.map(|e| e.expect("ps2d: failed to get next event").user_data) { @@ -131,5 +165,5 @@ fn daemon(daemon: daemon::Daemon) -> ! { } fn main() { - daemon::Daemon::new(daemon); + run(); } diff --git a/drivers/input/ps2d/src/state.rs b/drivers/input/ps2d/src/state.rs index 9018dc6b..2721c4fd 100644 --- a/drivers/input/ps2d/src/state.rs +++ b/drivers/input/ps2d/src/state.rs @@ -59,9 +59,18 @@ pub struct Ps2d { } impl Ps2d { - pub fn new(input: ProducerHandle, time_file: File) -> Self { + pub fn new(input: ProducerHandle, time_file: File) -> Option { let mut ps2 = Ps2::new(); - ps2.init().expect("failed to initialize"); + + if !ps2.probe() { + warn!("ps2d: no PS/2 controller detected, skipping initialization"); + return None; + } + + if let Err(err) = ps2.init() { + error!("ps2d: failed to initialize PS/2 controller: {:?}", err); + return None; + } // FIXME add an option for orbital to disable this when an app captures the mouse. let vmmouse_relative = false; @@ -70,7 +79,7 @@ impl Ps2d { // TODO: QEMU hack, maybe do this when Init timed out? 
if vmmouse { // 3 = MouseId::Intellimouse1 - MouseState::Bat.handle(3, &mut ps2); + let _ = MouseState::Bat.handle(3, &mut ps2); } let mut this = Ps2d { @@ -96,7 +105,7 @@ impl Ps2d { this.handle_mouse(None); } - this + Some(this) } pub fn irq(&mut self) { diff --git a/drivers/input/usbhidd/src/main.rs b/drivers/input/usbhidd/src/main.rs diff --git a/drivers/input/ps2d/src/controller.rs b/drivers/input/ps2d/src/controller.rs index 561aa527..0310a367 100644 --- a/drivers/input/ps2d/src/controller.rs +++ b/drivers/input/ps2d/src/controller.rs @@ -283,8 +283,27 @@ impl Ps2 { status_bits ); } + let flushed = self.flush_output(); + if flushed != 0 { + debug!("ps/2 controller probe drained {} stale byte(s)", flushed); + } self.config().is_ok() } + + pub fn flush_output(&mut self) -> usize { + let mut flushed = 0; + while let Some((keyboard, data)) = self.next() { + flushed += 1; + trace!( + "ps/2 flush discarded {:02X} from {} channel", + data, + if keyboard { "keyboard" } else { "mouse" } + ); + } + flushed + } pub fn init_keyboard(&mut self) -> Result<(), Error> { let mut b; @@ -325,6 +344,11 @@ impl Ps2 { self.command(Command::DisableSecond)?; } + let flushed = self.flush_output(); + if flushed != 0 { + debug!("ps/2 init discarded {} stale byte(s) before config", flushed); + } + // Disable clocks, disable interrupts, and disable translate { // Since the default config may have interrupts enabled, and the kernel may eat up @@ -358,6 +382,11 @@ impl Ps2 { warn!("self test unexpected value: {:02X}", r); } } + + let flushed = self.flush_output(); + if flushed != 0 { + debug!("ps/2 init discarded {} byte(s) after controller self-test", flushed); + } // Initialize keyboard if let Err(err) = self.init_keyboard() { diff --git a/drivers/input/usbhidd/src/main.rs b/drivers/input/usbhidd/src/main.rs index 15c5b778..c67fb8bc 100644 --- a/drivers/input/usbhidd/src/main.rs +++ b/drivers/input/usbhidd/src/main.rs @@ -159,17 +159,17 @@ fn main() -> Result<()> { const USAGE: &'static 
str = "usbhidd "; - let scheme = args.next().expect(USAGE); + let scheme = args.next().ok_or_else(|| anyhow::anyhow!(USAGE))?; let port = args .next() - .expect(USAGE) + .ok_or_else(|| anyhow::anyhow!(USAGE))? .parse::() - .expect("Expected port ID"); + .map_err(|err| anyhow::anyhow!("Expected port ID: {err}"))?; let interface_num = args .next() - .expect(USAGE) + .ok_or_else(|| anyhow::anyhow!(USAGE))? .parse::() - .expect("Expected integer as input of interface"); + .context("Expected integer as input of interface")?; let name = format!("{}_{}_{}_hid", scheme, port, interface_num); common::setup_logging( @@ -247,7 +247,13 @@ fn main() -> Result<()> { reqs::set_idle(&handle, 1, 0, interface_num as u16).context("Failed to set idle")?; let report_desc_len = hid_desc.desc_len; - assert_eq!(hid_desc.desc_ty, REPORT_DESC_TY); + if hid_desc.desc_ty != REPORT_DESC_TY { + anyhow::bail!( + "unexpected HID descriptor type {:X}, expected {:X}", + hid_desc.desc_ty, + REPORT_DESC_TY + ); + } let mut report_desc_bytes = vec![0u8; report_desc_len as usize]; handle @@ -261,8 +267,8 @@ fn main() -> Result<()> { ) .context("Failed to retrieve report descriptor")?; - let mut handler = - ReportHandler::new(&report_desc_bytes).expect("failed to parse report descriptor"); + let mut handler = ReportHandler::new(&report_desc_bytes) + .map_err(|e| anyhow::anyhow!("failed to parse report descriptor: {}", e))?; let report_len = match endp_desc_opt { Some((_endp_num, endp_desc)) => endp_desc.max_packet_size as usize, @@ -318,10 +324,14 @@ fn main() -> Result<()> { let mut mouse_dy = 0i32; let mut scroll_y = 0i32; let mut buttons = last_buttons; - for event in handler - .handle(&report_buffer) - .expect("failed to parse report") - { + let events = match handler.handle(&report_buffer) { + Ok(events) => events, + Err(err) => { + log::warn!("failed to parse report: {}", err); + continue; + } + }; + for event in events { log::debug!("{}", event); if event.usage_page == UsagePage::GenericDesktop 
as u16 { if event.usage == GenericDesktopUsage::X as u16 { diff --git a/drivers/pcid-spawner/Cargo.toml b/drivers/pcid-spawner/Cargo.toml index 8c03f8d3..8d3b3899 100644 --- a/drivers/pcid-spawner/Cargo.toml +++ b/drivers/pcid-spawner/Cargo.toml @@ -13,6 +13,7 @@ pico-args.workspace = true redox_syscall.workspace = true serde.workspace = true toml.workspace = true +redox-driver-sys = { path = "../../../../../../../local/recipes/drivers/redox-driver-sys/source" } config = { path = "../../config" } common = { path = "../common" } diff --git a/drivers/pcid-spawner/src/main.rs b/drivers/pcid-spawner/src/main.rs index a968f4d4..c7082b0b 100644 --- a/drivers/pcid-spawner/src/main.rs +++ b/drivers/pcid-spawner/src/main.rs @@ -1,10 +1,56 @@ use std::fs; +use std::path::Path; use std::process::Command; use anyhow::{anyhow, Context, Result}; use pcid_interface::config::Config; use pcid_interface::PciFunctionHandle; +use redox_driver_sys::pci::{PciDeviceInfo, PciLocation}; + +const PCI_SUBSYSTEM_IDS_OFFSET: u16 = 0x2C; + +fn parse_location_from_device_path(path: &Path) -> Option { + let name = path.file_name()?.to_str()?; + let (segment, rest) = name.split_once("--")?; + let (bus, rest) = rest.split_once("--")?; + let (device, function) = rest.split_once('.')?; + + Some(PciLocation { + segment: u16::from_str_radix(segment, 16).ok()?, + bus: u8::from_str_radix(bus, 16).ok()?, + device: u8::from_str_radix(device, 16).ok()?, + function: function.parse().ok()?, + }) +} + +fn read_subsystem_ids(handle: &mut PciFunctionHandle) -> (u16, u16) { + let value = unsafe { handle.read_config(PCI_SUBSYSTEM_IDS_OFFSET) }; + (value as u16, (value >> 16) as u16) +} + +fn build_quirk_info(handle: &mut PciFunctionHandle, device_path: &Path) -> Option { + let config = handle.config(); + let full_device_id = config.func.full_device_id; + let location = parse_location_from_device_path(device_path)?; + let (subsystem_vendor_id, subsystem_device_id) = read_subsystem_ids(handle); + + 
Some(PciDeviceInfo { + location, + vendor_id: full_device_id.vendor_id, + device_id: full_device_id.device_id, + subsystem_vendor_id, + subsystem_device_id, + revision: full_device_id.revision, + class_code: full_device_id.class, + subclass: full_device_id.subclass, + prog_if: full_device_id.interface, + header_type: 0, + irq: None, + bars: Vec::new(), + capabilities: Vec::new(), + }) +} fn main() -> Result<()> { let mut args = pico_args::Arguments::from_env(); @@ -85,6 +131,20 @@ fn main() -> Result<()> { let mut command = Command::new(program); command.args(args); + if let Some(info) = build_quirk_info(&mut handle, &device_path) { + let quirk_flags = info.quirks(); + if !quirk_flags.is_empty() { + log::info!( + "pcid-spawner: quirks for {} {:04x}:{:04x} = {:?}", + info.location.scheme_path(), + info.vendor_id, + info.device_id, + quirk_flags + ); + } + command.env("PCI_QUIRK_FLAGS", format!("{:#x}", quirk_flags.bits())); + } + log::info!("pcid-spawner: spawn {:?}", command); handle.enable_device(); @@ -99,3 +159,20 @@ fn main() -> Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::parse_location_from_device_path; + use std::path::Path; + + #[test] + fn parses_scheme_pci_path_name() { + let location = parse_location_from_device_path(Path::new("/scheme/pci/0000--2a--1f.3")) + .expect("parse location"); + + assert_eq!(location.segment, 0); + assert_eq!(location.bus, 0x2a); + assert_eq!(location.device, 0x1f); + assert_eq!(location.function, 3); + } +} diff --git a/drivers/pcid/src/scheme.rs b/drivers/pcid/src/scheme.rs index c2caf804..95acdb57 100644 --- a/drivers/pcid/src/scheme.rs +++ b/drivers/pcid/src/scheme.rs @@ -21,6 +21,7 @@ enum Handle { TopLevel { entries: Vec }, Access, Device, + Config { addr: PciAddress }, Channel { addr: PciAddress, st: ChannelState }, SchemeRoot, } @@ -30,14 +31,20 @@ struct HandleWrapper { } impl Handle { fn is_file(&self) -> bool { - matches!(self, Self::Access | Self::Channel { .. 
}) + matches!( + self, + Self::Access | Self::Config { .. } | Self::Channel { .. } + ) } fn is_dir(&self) -> bool { !self.is_file() } // TODO: capability rather than root fn requires_root(&self) -> bool { - matches!(self, Self::Access | Self::Channel { .. }) + matches!( + self, + Self::Access | Self::Config { .. } | Self::Channel { .. } + ) } fn is_scheme_root(&self) -> bool { matches!(self, Self::SchemeRoot) @@ -132,6 +139,7 @@ impl SchemeSync for PciScheme { let (len, mode) = match handle.inner { Handle::TopLevel { ref entries } => (entries.len(), MODE_DIR | 0o755), Handle::Device => (DEVICE_CONTENTS.len(), MODE_DIR | 0o755), + Handle::Config { .. } => (256, MODE_CHR | 0o600), Handle::Access | Handle::Channel { .. } => (0, MODE_CHR | 0o600), Handle::SchemeRoot => return Err(Error::new(EBADF)), }; @@ -156,6 +164,18 @@ impl SchemeSync for PciScheme { match handle.inner { Handle::TopLevel { .. } => Err(Error::new(EISDIR)), Handle::Device => Err(Error::new(EISDIR)), + Handle::Config { addr } => { + /* Reads at or past the 256-byte size reported for this handle behave like EOF; a checked conversion also keeps huge offsets from wrapping back into range as `as u16` would. */ let offset = match u16::try_from(_offset) { Ok(offset) if offset < 256 => offset, _ => return Ok(0), }; + let dword_offset = offset & !0x3; + let byte_offset = (offset & 0x3) as usize; + let bytes_to_read = buf.len().min(4 - byte_offset); + + let dword = unsafe { self.pcie.read(addr, dword_offset) }; + let bytes = dword.to_le_bytes(); + buf[..bytes_to_read] + .copy_from_slice(&bytes[byte_offset..byte_offset + bytes_to_read]); + Ok(bytes_to_read) + } Handle::Channel { addr: _, ref mut st, @@ -193,7 +213,9 @@ impl SchemeSync for PciScheme { return Ok(buf); } Handle::Device => DEVICE_CONTENTS, - Handle::Access | Handle::Channel { .. } => return Err(Error::new(ENOTDIR)), + Handle::Access | Handle::Config { .. } | Handle::Channel { ..
} => { + return Err(Error::new(ENOTDIR)); + } Handle::SchemeRoot => return Err(Error::new(EBADF)), }; @@ -223,6 +245,20 @@ impl SchemeSync for PciScheme { } match handle.inner { + Handle::Config { addr } => { + /* Writes at or past the 256-byte size reported for this handle write nothing (EOF semantics), and an empty write returns early so it does not trigger a needless read-modify-write of the dword. */ let offset = match u16::try_from(_offset) { Ok(offset) if offset < 256 && !buf.is_empty() => offset, _ => return Ok(0), }; + let dword_offset = offset & !0x3; + let byte_offset = (offset & 0x3) as usize; + let bytes_to_write = buf.len().min(4 - byte_offset); + + let mut dword = unsafe { self.pcie.read(addr, dword_offset) }; + let mut bytes = dword.to_le_bytes(); + bytes[byte_offset..byte_offset + bytes_to_write] + .copy_from_slice(&buf[..bytes_to_write]); + dword = u32::from_le_bytes(bytes); + unsafe { self.pcie.write(addr, dword_offset, dword) }; + /* Only the bytes inside the addressed dword were written; report that count so callers continue with the remainder instead of losing it. */ Ok(bytes_to_write) + } Handle::Channel { addr, ref mut st } => { Self::write_channel(&self.pcie, &mut self.tree, addr, st, buf) } @@ -318,6 +354,10 @@ impl PciScheme { func.enabled = false; } } + Some(HandleWrapper { + inner: Handle::Config { .. }, + .. + }) => {} _ => {} } } @@ -343,6 +383,7 @@ impl PciScheme { let path = &after[1..]; match path { + "config" => Handle::Config { addr }, "channel" => { if func.enabled { return Err(Error::new(ENOLCK)); diff --git a/drivers/storage/usbscsid/Cargo.toml b/drivers/storage/usbscsid/Cargo.toml index 4a36934e..a9c6447c 100644 --- a/drivers/storage/usbscsid/Cargo.toml +++ b/drivers/storage/usbscsid/Cargo.toml @@ -10,6 +10,7 @@ license = "MIT" [dependencies] base64 = "0.11" # Only for debugging +bitflags.workspace = true libredox.workspace = true plain.workspace = true driver-block = { path = "../driver-block" } @@ -17,6 +18,7 @@ daemon = { path = "../../../daemon" } redox_event.workspace = true redox_syscall = { workspace = true, features = ["std"] } thiserror.workspace = true +toml.workspace = true xhcid = { path = "../../usb/xhcid" } [lints] diff --git a/drivers/storage/usbscsid/src/quirks.rs b/drivers/storage/usbscsid/src/quirks.rs new file mode 100644 index 00000000..5051f1b0 --- /dev/null +++ b/drivers/storage/usbscsid/src/quirks.rs @@ -0,0 +1,212 @@ +use std::fs;
+use std::path::{Path, PathBuf}; +use std::sync::OnceLock; + +use bitflags::bitflags; +use toml::Value; + +bitflags! { + #[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] + pub struct UsbStorageQuirkFlags: u32 { + const IGNORE_RESIDUE = 1 << 0; + const FIX_CAPACITY = 1 << 1; + const SINGLE_LUN = 1 << 2; + const MAX_SECTORS_64 = 1 << 3; + const INITIAL_READ10 = 1 << 4; + + const FIX_INQUIRY = 1 << 5; + const NOT_LOCKABLE = 1 << 6; + const SCM_MULT_TARG = 1 << 7; + const SANE_SENSE = 1 << 8; + const BULK_IGNORE_TAG = 1 << 9; + const NEEDS_SYNC_CACHE = 1 << 10; + const NO_WP_DETECT = 1 << 11; + const NO_READ_CAP16 = 1 << 12; + const IGNORE_DEVICE = 1 << 13; + } +} + +#[derive(Clone, Copy)] +struct CompiledQuirkEntry { + vendor: u16, + product: u16, + flags: UsbStorageQuirkFlags, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +struct RuntimeQuirkEntry { + vendor: u16, + product: u16, + flags: UsbStorageQuirkFlags, +} + +const COMPILED_QUIRKS: &[CompiledQuirkEntry] = &[ + CompiledQuirkEntry { vendor: 0x03EB, product: 0x2002, flags: UsbStorageQuirkFlags::IGNORE_RESIDUE }, + CompiledQuirkEntry { vendor: 0x03F0, product: 0x4002, flags: UsbStorageQuirkFlags::FIX_CAPACITY }, + CompiledQuirkEntry { vendor: 0x0409, product: 0x0040, flags: UsbStorageQuirkFlags::SINGLE_LUN }, + CompiledQuirkEntry { vendor: 0x0421, product: 0x0019, flags: UsbStorageQuirkFlags::MAX_SECTORS_64 }, + CompiledQuirkEntry { vendor: 0x090C, product: 0x6000, flags: UsbStorageQuirkFlags::INITIAL_READ10 }, + CompiledQuirkEntry { vendor: 0x1B1C, product: 0x1AB5, flags: UsbStorageQuirkFlags::INITIAL_READ10 }, +]; + +static RUNTIME_QUIRKS: OnceLock> = OnceLock::new(); + +pub fn lookup_usb_storage_quirks(vendor: u16, product: u16) -> UsbStorageQuirkFlags { + let mut flags = UsbStorageQuirkFlags::empty(); + + for entry in COMPILED_QUIRKS { + if entry.vendor == vendor && entry.product == product { + flags |= entry.flags; + } + } + + for entry in runtime_quirks() { + if entry.vendor == vendor && 
entry.product == product { + flags |= entry.flags; + } + } + + flags +} + +fn runtime_quirks() -> &'static [RuntimeQuirkEntry] { + RUNTIME_QUIRKS.get_or_init(load_runtime_quirks).as_slice() +} + +fn load_runtime_quirks() -> Vec { + let mut entries = Vec::new(); + let Some(dir_entries) = quirk_files() else { + return entries; + }; + + for path in dir_entries { + let Ok(text) = fs::read_to_string(&path) else { + continue; + }; + entries.extend(parse_runtime_quirks_from_toml(&text)); + } + + entries +} + +fn quirk_files() -> Option> { + let quirks_dir = Path::new("/etc/quirks.d"); + let read_dir = fs::read_dir(quirks_dir).ok()?; + + let mut files = read_dir + .filter_map(|entry| entry.ok()) + .map(|entry| entry.path()) + .filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("toml")) + .collect::>(); + files.sort(); + Some(files) +} + +fn parse_runtime_quirks_from_toml(text: &str) -> Vec { + let Ok(value) = text.parse::() else { + return Vec::new(); + }; + + let Some(entries) = value.get("usb_storage_quirk").and_then(Value::as_array) else { + return Vec::new(); + }; + + entries.iter().filter_map(parse_runtime_quirk_entry).collect() +} + +fn parse_runtime_quirk_entry(value: &Value) -> Option { + let table = value.as_table()?; + let vendor = u16::try_from(table.get("vendor")?.as_integer()?).ok()?; + let product = u16::try_from(table.get("product")?.as_integer()?).ok()?; + let flags = parse_flag_list(table.get("flags")?.as_array()?); + + (!flags.is_empty()).then_some(RuntimeQuirkEntry { vendor, product, flags }) +} + +fn parse_flag_list(values: &[Value]) -> UsbStorageQuirkFlags { + let mut flags = UsbStorageQuirkFlags::empty(); + + for value in values { + if let Some(name) = value.as_str().and_then(parse_flag_name) { + flags |= name; + } + } + + flags +} + +fn parse_flag_name(name: &str) -> Option { + Some(match name { + "ignore_residue" => UsbStorageQuirkFlags::IGNORE_RESIDUE, + "fix_capacity" => UsbStorageQuirkFlags::FIX_CAPACITY, + "single_lun" => 
UsbStorageQuirkFlags::SINGLE_LUN, + "max_sectors_64" => UsbStorageQuirkFlags::MAX_SECTORS_64, + "initial_read10" => UsbStorageQuirkFlags::INITIAL_READ10, + "fix_inquiry" => UsbStorageQuirkFlags::FIX_INQUIRY, + "not_lockable" => UsbStorageQuirkFlags::NOT_LOCKABLE, + "scm_mult_targ" => UsbStorageQuirkFlags::SCM_MULT_TARG, + "sane_sense" => UsbStorageQuirkFlags::SANE_SENSE, + "bulk_ignore_tag" => UsbStorageQuirkFlags::BULK_IGNORE_TAG, + "needs_sync_cache" => UsbStorageQuirkFlags::NEEDS_SYNC_CACHE, + "no_wp_detect" => UsbStorageQuirkFlags::NO_WP_DETECT, + "no_read_cap16" => UsbStorageQuirkFlags::NO_READ_CAP16, + "ignore_device" => UsbStorageQuirkFlags::IGNORE_DEVICE, + _ => return None, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn compiled_fallback_lookup_returns_expected_flags() { + let flags = lookup_usb_storage_quirks(0x090C, 0x6000); + assert!(flags.contains(UsbStorageQuirkFlags::INITIAL_READ10)); + } + + #[test] + fn runtime_toml_parser_keeps_supported_flags_and_skips_unknown_ones() { + let entries = parse_runtime_quirks_from_toml( + r#" + [[usb_storage_quirk]] + vendor = 0x1234 + product = 0x5678 + flags = ["ignore_residue", "unknown_flag", "fix_capacity"] + "#, + ); + + assert_eq!(entries.len(), 1); + assert!(entries[0].flags.contains(UsbStorageQuirkFlags::IGNORE_RESIDUE)); + assert!(entries[0].flags.contains(UsbStorageQuirkFlags::FIX_CAPACITY)); + assert!(!entries[0].flags.contains(UsbStorageQuirkFlags::SINGLE_LUN)); + } +} diff --git a/drivers/storage/usbscsid/src/main.rs b/drivers/storage/usbscsid/src/main.rs index 5382d118..dca7762c 100644 --- a/drivers/storage/usbscsid/src/main.rs +++ b/drivers/storage/usbscsid/src/main.rs @@ -1,53 +1,67 @@ use std::collections::BTreeMap; use std::env; +use std::error::Error as StdError; +use std::io; use driver_block::{Disk, DiskScheme, ExecutorTrait}; use syscall::{Error, EIO}; use xhcid_interface::{ConfigureEndpointsReq, PortId, XhciClientHandle}; pub mod protocol; +pub mod quirks; pub mod scsi; 
use crate::protocol::Protocol; use crate::scsi::Scsi; +type Result = std::result::Result>; + +const USAGE: &str = "usbscsid "; + fn main() { daemon::Daemon::new(daemon); } fn daemon(daemon: daemon::Daemon) -> ! { - let mut args = env::args().skip(1); + let exit_code = match run(daemon) { + Ok(()) => 0, + Err(err) => { + eprintln!("usbscsid: {err}"); + 1 + } + }; - const USAGE: &'static str = "usbscsid "; + std::process::exit(exit_code); +} + +fn run(daemon: daemon::Daemon) -> Result<()> { + let mut args = env::args().skip(1); - let scheme = args.next().expect(USAGE); - let port = args + let scheme = next_arg(&mut args, "scheme")?; + let port: PortId = args .next() - .expect(USAGE) - .parse::() - .expect("Expected port ID"); - let protocol = args + .ok_or_else(|| usage_error("missing port argument"))? + .parse() + .map_err(|e| usage_error(format!("invalid port ID: {e}")))?; + let protocol_num: u8 = args .next() - .expect(USAGE) - .parse::() - .expect("protocol has to be a number 0-255"); + .ok_or_else(|| usage_error("missing protocol argument"))? + .parse() + .map_err(|e| usage_error(format!("protocol must be a number 0-255: {e}")))?; println!( "USB SCSI driver spawned with scheme `{}`, port {}, protocol {}", - scheme, port, protocol + scheme, port, protocol_num ); let disk_scheme_name = format!("disk.usb-{scheme}+{port}-scsi"); - // TODO: Use eventfds. - let handle = - XhciClientHandle::new(scheme.to_owned(), port).expect("Failed to open XhciClientHandle"); + let handle = XhciClientHandle::new(scheme.to_owned(), port) + .map_err(|e| runtime_error(format!("failed to open XhciClientHandle: {e}")))?; let desc = handle .get_standard_descs() - .expect("Failed to get standard descriptors"); + .map_err(|e| runtime_error(format!("failed to get standard descriptors: {e}")))?; - // TODO: Perhaps the drivers should just be given the config, interface, and alternate setting - // from xhcid. 
let (conf_desc, configuration_value, (if_desc, interface_num, alternate_setting)) = desc .config_descs .iter() @@ -65,7 +79,7 @@ fn daemon(daemon: daemon::Daemon) -> ! { interface_desc, )) }) - .expect("Failed to find suitable configuration"); + .ok_or_else(|| runtime_error("failed to find suitable SCSI BOT configuration"))?; handle .configure_endpoints(&ConfigureEndpointsReq { @@ -74,20 +88,37 @@ fn daemon(daemon: daemon::Daemon) -> ! { alternate_setting: Some(alternate_setting), hub_ports: None, }) - .expect("Failed to configure endpoints"); - - let mut protocol = protocol::setup(&handle, protocol, &desc, &conf_desc, &if_desc) - .expect("Failed to setup protocol"); - - // TODO: Let all of the USB drivers fork or be managed externally, and xhcid won't have to keep - // track of all the drivers. - let mut scsi = Scsi::new(&mut *protocol).expect("usbscsid: failed to setup SCSI"); + .map_err(|e| runtime_error(format!("failed to configure endpoints: {e}")))?; + + let vendor = desc.vendor; + let product = desc.product; + let storage_quirks = quirks::lookup_usb_storage_quirks(vendor, product); + + let mut protocol = protocol::setup( + &handle, + protocol_num, + &desc, + &conf_desc, + &if_desc, + storage_quirks, + ) + .ok_or_else(|| { + runtime_error(format!( + "failed to setup protocol (protocol 0x{protocol_num:02x})" + )) + })?; + + let mut scsi = Scsi::new(&mut *protocol, storage_quirks) + .map_err(|e| runtime_error(format!("failed to setup SCSI: {e}")))?; println!("SCSI initialized"); let mut buffer = [0u8; 512]; - scsi.read(&mut *protocol, 0, &mut buffer).unwrap(); - println!("DISK CONTENT: {}", base64::encode(&buffer[..])); + match scsi.read(&mut *protocol, 0, &mut buffer) { + Ok(_) => println!("DISK CONTENT: {}", base64::encode(&buffer[..])), + Err(e) => eprintln!("usbscsid: initial sector read failed: {e}"), + } - let event_queue = event::EventQueue::new().unwrap(); + let event_queue = event::EventQueue::new() + .map_err(|e| runtime_error(format!("failed to 
create event queue: {e}")))?; event::user_data! { enum Event { @@ -119,17 +150,41 @@ fn daemon(daemon: daemon::Daemon) -> ! { Event::Scheme, event::EventFlags::READ, ) - .unwrap(); + .map_err(|e| runtime_error(format!("failed to subscribe to scheme events: {e}")))?; for event in event_queue { - match event.unwrap().user_data { - Event::Scheme => driver_block::FuturesExecutor - .block_on(scheme.tick()) - .unwrap(), + match event { + Ok(ev) => match ev.user_data { + Event::Scheme => { + if let Err(e) = driver_block::FuturesExecutor.block_on(scheme.tick()) { + eprintln!("usbscsid: scheme tick error: {e}"); + } + } + }, + Err(e) => { + eprintln!("usbscsid: event queue error: {e}"); + } } } - std::process::exit(0); + Err(runtime_error("event queue terminated").into()) +} + +fn next_arg(args: &mut impl Iterator, name: &str) -> io::Result { + args.next() + .ok_or_else(|| usage_error(format!("missing {name} argument"))) +} + +fn usage_error(message: impl Into) -> io::Error { + let message = message.into(); + io::Error::new( + io::ErrorKind::InvalidInput, + format!("{message} (usage: {USAGE})"), + ) +} + +fn runtime_error(message: impl Into) -> io::Error { + io::Error::other(message.into()) } struct UsbDisk<'a> { diff --git a/drivers/storage/usbscsid/src/protocol/bot.rs b/drivers/storage/usbscsid/src/protocol/bot.rs index b751d51a..848ae0e9 100644 --- a/drivers/storage/usbscsid/src/protocol/bot.rs +++ b/drivers/storage/usbscsid/src/protocol/bot.rs @@ -8,6 +8,7 @@ use xhcid_interface::{ }; use super::{Protocol, ProtocolError, SendCommandStatus, SendCommandStatusKind}; +use crate::quirks::UsbStorageQuirkFlags; pub const CBW_SIGNATURE: u32 = 0x43425355; @@ -88,9 +89,12 @@ pub struct BulkOnlyTransport<'a> { bulk_out: XhciEndpHandle, bulk_in_num: u8, bulk_out_num: u8, + bulk_in_addr: u8, + bulk_out_addr: u8, max_lun: u8, current_tag: u32, interface_num: u8, + quirks: UsbStorageQuirkFlags, } pub const FEATURE_ENDPOINT_HALT: u16 = 0; @@ -98,23 +102,29 @@ pub const 
FEATURE_ENDPOINT_HALT: u16 = 0; impl<'a> BulkOnlyTransport<'a> { pub fn init( handle: &'a XhciClientHandle, - config_desc: &ConfDesc, + _config_desc: &ConfDesc, if_desc: &IfDesc, + quirks: UsbStorageQuirkFlags, ) -> Result { let endpoints = &if_desc.endpoints; - let bulk_in_num = (endpoints + let (bulk_in_idx, bulk_in_desc) = endpoints .iter() - .position(|endpoint| endpoint.direction() == EndpDirection::In) - .unwrap() - + 1) as u8; - let bulk_out_num = (endpoints + .enumerate() + .find(|(_, endpoint)| endpoint.direction() == EndpDirection::In) + .ok_or(ProtocolError::ProtocolError("no bulk IN endpoint found"))?; + let (bulk_out_idx, bulk_out_desc) = endpoints .iter() - .position(|endpoint| endpoint.direction() == EndpDirection::Out) - .unwrap() - + 1) as u8; + .enumerate() + .find(|(_, endpoint)| endpoint.direction() == EndpDirection::Out) + .ok_or(ProtocolError::ProtocolError("no bulk OUT endpoint found"))?; - let max_lun = get_max_lun(handle, 0)?; + let bulk_in_num = (bulk_in_idx + 1) as u8; + let bulk_out_num = (bulk_out_idx + 1) as u8; + let bulk_in_addr = bulk_in_desc.address; + let bulk_out_addr = bulk_out_desc.address; + + let max_lun = get_max_lun(handle, if_desc.number.into())?; println!("BOT_MAX_LUN {}", max_lun); Ok(Self { @@ -122,10 +132,13 @@ impl<'a> BulkOnlyTransport<'a> { bulk_out: handle.open_endpoint(bulk_out_num)?, bulk_in_num, bulk_out_num, + bulk_in_addr, + bulk_out_addr, handle, max_lun, current_tag: 0, interface_num: if_desc.number, + quirks, }) } fn clear_stall_in(&mut self) -> Result<(), XhciClientHandleError> { @@ -133,7 +146,7 @@ impl<'a> BulkOnlyTransport<'a> { self.bulk_in.reset(false)?; self.handle.clear_feature( PortReqRecipient::Endpoint, - u16::from(self.bulk_in_num), + u16::from(self.bulk_in_addr), FEATURE_ENDPOINT_HALT, )?; } @@ -144,7 +157,7 @@ impl<'a> BulkOnlyTransport<'a> { self.bulk_out.reset(false)?; self.handle.clear_feature( PortReqRecipient::Endpoint, - u16::from(self.bulk_out_num), + u16::from(self.bulk_out_addr), 
FEATURE_ENDPOINT_HALT, )?; } @@ -162,38 +175,59 @@ impl<'a> BulkOnlyTransport<'a> { } Ok(()) } - fn read_csw_raw( - &mut self, - csw_buffer: &mut [u8; 13], - already: bool, - ) -> Result<(), ProtocolError> { - match self.bulk_in.transfer_read(&mut csw_buffer[..])? { - PortTransferStatus { - kind: PortTransferStatusKind::Stalled, - .. - } => { - if already { + fn read_csw(&mut self, csw_buffer: &mut [u8; 13]) -> Result<(), ProtocolError> { + let mut attempts = 0u8; + loop { + let status = self.bulk_in.transfer_read(&mut csw_buffer[..])?; + match status { + PortTransferStatus { + kind: PortTransferStatusKind::Stalled, + .. + } => { + attempts += 1; + if attempts >= 2 { + self.reset_recovery()?; + return Err(ProtocolError::ProtocolError( + "bulk IN stalled repeatedly when reading CSW", + )); + } + eprintln!("usbscsid: bulk IN stalled when reading CSW, clearing stall"); + self.clear_stall_in()?; + continue; + } + PortTransferStatus { + kind: PortTransferStatusKind::ShortPacket, + bytes_transferred, + } if bytes_transferred != 13 => { + eprintln!( + "usbscsid: short packet when reading CSW ({} != 13)", + bytes_transferred + ); self.reset_recovery()?; + return Err(ProtocolError::ProtocolError( + "short packet when reading CSW", + )); + } + PortTransferStatus { + kind: PortTransferStatusKind::Success, + .. 
+ } + | PortTransferStatus { + kind: PortTransferStatusKind::ShortPacket, + bytes_transferred: 13, + } => return Ok(()), + _ => { + eprintln!( + "usbscsid: unexpected transfer status when reading CSW: {:?}", + status + ); + self.reset_recovery()?; + return Err(ProtocolError::ProtocolError( + "unexpected transfer status when reading CSW", + )); } - println!("bulk in endpoint stalled when reading CSW"); - self.clear_stall_in()?; - self.read_csw_raw(csw_buffer, true)?; - } - PortTransferStatus { - kind: PortTransferStatusKind::ShortPacket, - bytes_transferred, - } if bytes_transferred != 13 => { - panic!( - "received a short packet when reading CSW ({} != 13)", - bytes_transferred - ) } - _ => (), } - Ok(()) - } - fn read_csw(&mut self, csw_buffer: &mut [u8; 13]) -> Result<(), ProtocolError> { - self.read_csw_raw(csw_buffer, false) } } @@ -207,8 +241,14 @@ impl<'a> Protocol for BulkOnlyTransport<'a> { let tag = self.current_tag; let mut cbw_bytes = [0u8; 31]; - let cbw = plain::from_mut_bytes::(&mut cbw_bytes).unwrap(); - *cbw = CommandBlockWrapper::new(tag, data.len() as u32, data.direction().into(), 0, cb)?; + let cbw = plain::from_mut_bytes::(&mut cbw_bytes) + .map_err(|_| ProtocolError::ProtocolError("CBW buffer size mismatch"))?; + let lun: u8 = if self.quirks.contains(UsbStorageQuirkFlags::SINGLE_LUN) { + 0 + } else { + 0 + }; + *cbw = CommandBlockWrapper::new(tag, data.len() as u32, data.direction().into(), lun, cb)?; let cbw = *cbw; match self.bulk_out.transfer_write(&cbw_bytes)? { @@ -216,22 +256,48 @@ impl<'a> Protocol for BulkOnlyTransport<'a> { kind: PortTransferStatusKind::Stalled, .. 
} => { - // TODO: Error handling - panic!("bulk out endpoint stalled when sending CBW {:?}", cbw); - //self.clear_stall_out()?; - //dbg!(self.bulk_in.status()?, self.bulk_out.status()?); + eprintln!( + "usbscsid: bulk OUT endpoint stalled when sending CBW {:?}", + cbw + ); + self.clear_stall_out()?; + return Err(ProtocolError::ProtocolError( + "bulk OUT endpoint stalled when sending CBW", + )); } PortTransferStatus { bytes_transferred, .. } if bytes_transferred != 31 => { - panic!( - "received short packet when sending CBW ({} != 31)", + eprintln!( + "usbscsid: short packet when sending CBW ({} != 31)", bytes_transferred ); + self.reset_recovery()?; + return Err(ProtocolError::ProtocolError( + "short packet when sending CBW", + )); + } + PortTransferStatus { + kind: PortTransferStatusKind::Success, + .. + } => (), + PortTransferStatus { + kind: PortTransferStatusKind::ShortPacket, + bytes_transferred: 31, + } => (), + status => { + eprintln!( + "usbscsid: unexpected transfer status {:?} when sending CBW", + status + ); + self.reset_recovery()?; + return Err(ProtocolError::ProtocolError( + "unexpected transfer status when sending CBW", + )); } - _ => (), } + let data_len = data.len() as u32; let early_residue: Option = match data { DeviceReqData::In(buffer) => match self.bulk_in.transfer_read(buffer)? 
{ PortTransferStatus { @@ -240,15 +306,19 @@ impl<'a> Protocol for BulkOnlyTransport<'a> { } => match kind { PortTransferStatusKind::Success => None, PortTransferStatusKind::ShortPacket => { - println!( - "received short packet (len {}) when transferring data", - bytes_transferred + let residue = data_len.saturating_sub(bytes_transferred); + eprintln!( + "usbscsid: short packet ({} of {} bytes) during data read", + bytes_transferred, data_len ); - NonZeroU32::new(bytes_transferred) + NonZeroU32::new(residue) } PortTransferStatusKind::Stalled => { - panic!("bulk in endpoint stalled when reading data"); - //self.clear_stall_in()?; + eprintln!("usbscsid: bulk IN endpoint stalled when reading data"); + self.clear_stall_in()?; + return Err(ProtocolError::ProtocolError( + "bulk IN endpoint stalled during data read", + )); } PortTransferStatusKind::Unknown => { return Err(ProtocolError::XhciError( @@ -266,15 +336,19 @@ impl<'a> Protocol for BulkOnlyTransport<'a> { } => match kind { PortTransferStatusKind::Success => None, PortTransferStatusKind::ShortPacket => { - println!( - "received short packet (len {}) when transferring data", - bytes_transferred + let residue = data_len.saturating_sub(bytes_transferred); + eprintln!( + "usbscsid: short packet ({} of {} bytes) during data write", + bytes_transferred, data_len ); - NonZeroU32::new(bytes_transferred) + NonZeroU32::new(residue) } PortTransferStatusKind::Stalled => { - panic!("bulk out endpoint stalled when reading data"); - //self.clear_stall_out()?; + eprintln!("usbscsid: bulk OUT endpoint stalled when writing data"); + self.clear_stall_out()?; + return Err(ProtocolError::ProtocolError( + "bulk OUT endpoint stalled during data write", + )); } PortTransferStatusKind::Unknown => { return Err(ProtocolError::XhciError( @@ -290,9 +364,14 @@ impl<'a> Protocol for BulkOnlyTransport<'a> { let mut csw_buffer = [0u8; 13]; self.read_csw(&mut csw_buffer)?; - let csw = plain::from_bytes::(&csw_buffer).unwrap(); + let csw = 
plain::from_bytes::(&csw_buffer) + .map_err(|_| ProtocolError::ProtocolError("CSW buffer size mismatch"))?; - let residue = early_residue.or(NonZeroU32::new(csw.data_residue)); + let residue = if self.quirks.contains(UsbStorageQuirkFlags::IGNORE_RESIDUE) { + None + } else { + early_residue.or(NonZeroU32::new(csw.data_residue)) + }; if csw.status == CswStatus::Failed as u8 { println!("CSW indicated failure (CSW {:?}, CBW {:?})", csw, cbw); diff --git a/drivers/storage/usbscsid/src/protocol/mod.rs b/drivers/storage/usbscsid/src/protocol/mod.rs index a580765f..bde9affc 100644 --- a/drivers/storage/usbscsid/src/protocol/mod.rs +++ b/drivers/storage/usbscsid/src/protocol/mod.rs @@ -6,6 +6,8 @@ use xhcid_interface::{ ConfDesc, DevDesc, DeviceReqData, IfDesc, XhciClientHandle, XhciClientHandleError, }; +use crate::quirks::UsbStorageQuirkFlags; + #[derive(Debug, Error)] pub enum ProtocolError { #[error("Too large command block ({0} > 16)")] @@ -59,22 +61,19 @@ pub trait Protocol { /// Bulk-only transport pub mod bot; -mod uas { - // TODO -} - use bot::BulkOnlyTransport; pub fn setup<'a>( handle: &'a XhciClientHandle, protocol: u8, - dev_desc: &DevDesc, + _dev_desc: &DevDesc, conf_desc: &ConfDesc, if_desc: &IfDesc, + quirks: UsbStorageQuirkFlags, ) -> Option> { match protocol { 0x50 => Some(Box::new( - BulkOnlyTransport::init(handle, conf_desc, if_desc).unwrap(), + BulkOnlyTransport::init(handle, conf_desc, if_desc, quirks).ok()?, )), _ => None, } diff --git a/drivers/storage/usbscsid/src/scsi/cmds.rs b/drivers/storage/usbscsid/src/scsi/cmds.rs index ab02525e..ddc12336 100644 --- a/drivers/storage/usbscsid/src/scsi/cmds.rs +++ b/drivers/storage/usbscsid/src/scsi/cmds.rs @@ -179,9 +179,6 @@ unsafe impl plain::Plain for Read16 {} impl Read16 { pub const fn new(lba: u64, transfer_len: u32, control: u8) -> Self { - // TODO: RDPROTECT, DPO, FUA, RARC - // TODO: DLD - // TODO: Group number Self { opcode: Opcode::Read16 as u8, a: 0, @@ -193,6 +190,31 @@ impl Read16 { } } 
+#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct Read10 { + pub opcode: u8, + pub a: u8, + pub lba: u32, + pub group_num: u8, + pub transfer_len: u16, + pub control: u8, +} +unsafe impl plain::Plain for Read10 {} + +impl Read10 { + pub const fn new(lba: u64, transfer_len: u16, control: u8) -> Self { + Self { + opcode: Opcode::Read10 as u8, + a: 0, + lba: u32::to_be(lba as u32), + group_num: 0, + transfer_len: u16::to_be(transfer_len), + control, + } + } +} + #[repr(C, packed)] #[derive(Clone, Copy, Debug)] pub struct Write16 { @@ -219,6 +241,31 @@ impl Write16 { } } +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct Write10 { + pub opcode: u8, + pub a: u8, + pub lba: u32, + pub group_num: u8, + pub transfer_len: u16, + pub control: u8, +} +unsafe impl plain::Plain for Write10 {} + +impl Write10 { + pub const fn new(lba: u64, transfer_len: u16, control: u8) -> Self { + Self { + opcode: Opcode::Write10 as u8, + a: 0, + lba: u32::to_be(lba as u32), + group_num: 0, + transfer_len: u16::to_be(transfer_len), + control, + } + } +} + #[repr(C, packed)] #[derive(Clone, Copy, Debug)] pub struct ModeSense6 { @@ -438,7 +485,35 @@ impl ReadCapacity10 { } } } -// TODO: ReadCapacity16 + +/// SERVICE ACTION IN(16) with service action 0x10 (READ CAPACITY(16)). +/// Required for devices larger than 2 TB where ReadCapacity10 cannot +/// represent the full block count. 
+#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct ReadCapacity16 { + pub opcode: u8, + pub service_action: u8, + pub lba: u64, + pub alloc_len: u32, + pub _rsvd: u8, + pub control: u8, +} + +impl ReadCapacity16 { + pub const fn new(control: u8) -> Self { + Self { + opcode: Opcode::ServiceAction9E as u8, + service_action: 0x10, + lba: 0, + alloc_len: u32::to_be(32), + _rsvd: 0, + control, + } + } +} + +unsafe impl plain::Plain for ReadCapacity16 {} #[repr(C, packed)] #[derive(Clone, Copy, Debug)] @@ -457,6 +532,27 @@ impl ReadCapacity10ParamData { } } +/// Response data for READ CAPACITY(16). The minimum useful response is +/// 12 bytes (max LBA + block length), but the device may return up to +/// 32 bytes with additional protection and mapping information. +#[repr(C, packed)] +#[derive(Clone, Copy, Debug)] +pub struct ReadCapacity16ParamData { + pub max_lba: u64, + pub block_len: u32, + pub _rest: [u8; 20], +} +unsafe impl plain::Plain for ReadCapacity16ParamData {} + +impl ReadCapacity16ParamData { + pub const fn block_count(&self) -> u64 { + u64::from_be(self.max_lba) + } + pub const fn logical_block_len(&self) -> u32 { + u32::from_be(self.block_len) + } +} + #[repr(C, packed)] #[derive(Clone, Copy, Debug)] pub struct RwErrorRecoveryPage { diff --git a/drivers/storage/usbscsid/src/scsi/mod.rs b/drivers/storage/usbscsid/src/scsi/mod.rs index 790abea6..b6d379d0 100644 --- a/drivers/storage/usbscsid/src/scsi/mod.rs +++ b/drivers/storage/usbscsid/src/scsi/mod.rs @@ -8,6 +8,7 @@ use thiserror::Error; use xhcid_interface::DeviceReqData; use crate::protocol::{Protocol, ProtocolError, SendCommandStatus, SendCommandStatusKind}; +use crate::quirks::UsbStorageQuirkFlags; use cmds::StandardInquiryData; pub struct Scsi { @@ -16,6 +17,7 @@ pub struct Scsi { data_buffer: Vec, pub block_size: u32, pub block_count: u64, + pub quirks: UsbStorageQuirkFlags, } const INQUIRY_CMD_LEN: u8 = 6; @@ -23,6 +25,7 @@ const REPORT_SUPP_OPCODES_CMD_LEN: u8 = 12; const 
REQUEST_SENSE_CMD_LEN: u8 = 6; const MIN_INQUIRY_ALLOC_LEN: u16 = 5; const MIN_REPORT_SUPP_OPCODES_ALLOC_LEN: u32 = 4; +const MAX_SECTORS_64_LIMIT: u64 = 64; type Result = std::result::Result; @@ -35,11 +38,74 @@ pub enum ScsiError { #[error("overflow")] Overflow(&'static str), + + #[error("invalid size for {context}: expected {expected}, got {actual}")] + InvalidSize { + context: &'static str, + expected: usize, + actual: usize, + }, + + #[error("insufficient bytes for {context}: need {needed}, have {actual}")] + InsufficientBytes { + context: &'static str, + needed: usize, + actual: usize, + }, + + #[error("plain parse error for {context}: {message}")] + PlainParse { + context: &'static str, + message: String, + }, + + #[error("invalid block size reported by device: {0}")] + InvalidBlockSize(u32), +} + +fn parse_bytes<'a, T: plain::Plain>(context: &'static str, bytes: &'a [u8]) -> Result<&'a T> { + let needed = mem::size_of::(); + if bytes.len() < needed { + return Err(ScsiError::InsufficientBytes { + context, + needed, + actual: bytes.len(), + }); + } + plain::from_bytes(bytes).map_err(|e| ScsiError::PlainParse { + context, + message: format!("{e:?}"), + }) +} + +fn parse_mut_bytes<'a, T: plain::Plain>( + context: &'static str, + bytes: &'a mut [u8], +) -> Result<&'a mut T> { + let needed = mem::size_of::(); + if bytes.len() < needed { + return Err(ScsiError::InsufficientBytes { + context, + needed, + actual: bytes.len(), + }); + } + plain::from_mut_bytes(bytes).map_err(|e| ScsiError::PlainParse { + context, + message: format!("{e:?}"), + }) } impl Scsi { - pub fn new(protocol: &mut dyn Protocol) -> Result { - assert_eq!(std::mem::size_of::(), 96); + pub fn new(protocol: &mut dyn Protocol, quirks: UsbStorageQuirkFlags) -> Result { + let inquiry_size = std::mem::size_of::(); + if inquiry_size != 96 { + return Err(ScsiError::InvalidSize { + context: "StandardInquiryData", + expected: 96, + actual: inquiry_size, + }); + } let mut this = Self { command_buffer: [0u8; 
16], @@ -49,6 +115,7 @@ impl Scsi { data_buffer: Vec::new(), block_size: 0, block_count: 0, + quirks, }; // Get the max length that the device supports, of the Standard Inquiry Data. @@ -56,9 +123,11 @@ impl Scsi { // Get the Standard Inquiry Data. this.get_standard_inquiry_data(protocol, max_inquiry_len)?; - let version = this.res_standard_inquiry_data().version(); + let version = this.res_standard_inquiry_data()?.version(); println!("Inquiry version: {}", version); + let fix_capacity = this.quirks.contains(UsbStorageQuirkFlags::FIX_CAPACITY); + let (block_size, block_count) = { let (_, blkdescs, mode_page_iter) = this.get_mode_sense10(protocol)?; @@ -74,10 +143,25 @@ impl Scsi { println!("read_capacity10"); let r = this.read_capacity(protocol)?; println!("read_capacity10 result: {:?}", r); - (r.logical_block_len(), r.block_count().into()) + let mut count = r.block_count(); + if fix_capacity { + count = count.saturating_sub(1); + } + if count == u32::MAX { + println!("read_capacity10 returned max LBA, trying read_capacity16"); + let r16 = this.read_capacity16(protocol)?; + println!("read_capacity16 result: {:?}", r16); + (r16.logical_block_len(), r16.block_count()) + } else { + (r.logical_block_len(), u64::from(count)) + } } }; + if block_size == 0 { + return Err(ScsiError::InvalidBlockSize(block_size)); + } + this.block_size = block_size; this.block_count = block_count; @@ -85,7 +169,7 @@ impl Scsi { } pub fn get_inquiry_alloc_len(&mut self, protocol: &mut dyn Protocol) -> Result { self.get_standard_inquiry_data(protocol, MIN_INQUIRY_ALLOC_LEN)?; - let standard_inquiry_data = self.res_standard_inquiry_data(); + let standard_inquiry_data = self.res_standard_inquiry_data()?; Ok(4 + u16::from(standard_inquiry_data.additional_len)) } pub fn get_standard_inquiry_data( @@ -93,7 +177,7 @@ impl Scsi { protocol: &mut dyn Protocol, max_inquiry_len: u16, ) -> Result<()> { - let inquiry = self.cmd_inquiry(); + let inquiry = self.cmd_inquiry()?; *inquiry = 
cmds::Inquiry::new(false, 0, max_inquiry_len, 0); protocol.send_command( @@ -103,7 +187,7 @@ impl Scsi { Ok(()) } pub fn get_ff_sense(&mut self, protocol: &mut dyn Protocol, alloc_len: u8) -> Result<()> { - let request_sense = self.cmd_request_sense(); + let request_sense = self.cmd_request_sense()?; *request_sense = cmds::RequestSense::new(false, alloc_len, 0); self.data_buffer.resize(alloc_len.into(), 0); protocol.send_command( @@ -117,14 +201,14 @@ impl Scsi { protocol: &mut dyn Protocol, ) -> Result<&cmds::ReadCapacity10ParamData> { // The spec explicitly states that the allocation length is 8 bytes. - let read_capacity10 = self.cmd_read_capacity10(); + let read_capacity10 = self.cmd_read_capacity10()?; *read_capacity10 = cmds::ReadCapacity10::new(0); self.data_buffer.resize(10usize, 0u8); protocol.send_command( &self.command_buffer[..10], DeviceReqData::In(&mut self.data_buffer[..8]), )?; - Ok(self.res_read_capacity10()) + self.res_read_capacity10() } pub fn get_mode_sense10( &mut self, @@ -135,7 +219,7 @@ impl Scsi { impl Iterator>, )> { let initial_alloc_len = mem::size_of::() as u16; // covers both mode_data_len and blk_desc_len. - let mode_sense10 = self.cmd_mode_sense10(); + let mode_sense10 = self.cmd_mode_sense10()?; *mode_sense10 = cmds::ModeSense10::get_block_desc(initial_alloc_len, 0); self.data_buffer .resize(mem::size_of::(), 0); @@ -146,108 +230,166 @@ impl Scsi { &self.command_buffer[..10], DeviceReqData::In(&mut self.data_buffer[..initial_alloc_len as usize]), )? 
{ - self.get_ff_sense(protocol, 252)?; - panic!("{:?}", self.res_ff_sense_data()); + if let Ok(()) = self.get_ff_sense(protocol, 252) { + match self.res_ff_sense_data() { + Ok(sense_data) => { + eprintln!("usbscsid: MODE SENSE(10) failed: {:?}", sense_data); + } + Err(err) => { + eprintln!( + "usbscsid: MODE SENSE(10) failed and sense parsing failed: {err}" + ); + } + } + } + return Err(ScsiError::ProtocolError(ProtocolError::ProtocolError( + "MODE SENSE(10) command failed", + ))); } - let optimal_alloc_len = self.res_mode_param_header10().mode_data_len() + 2; // the length of the mode data field itself + let optimal_alloc_len = self.res_mode_param_header10()?.mode_data_len() + 2; // the length of the mode data field itself - let mode_sense10 = self.cmd_mode_sense10(); + let mode_sense10 = self.cmd_mode_sense10()?; *mode_sense10 = cmds::ModeSense10::get_block_desc(optimal_alloc_len, 0); self.data_buffer.resize(optimal_alloc_len as usize, 0); protocol.send_command( &self.command_buffer[..10], DeviceReqData::In(&mut self.data_buffer[..optimal_alloc_len as usize]), )?; - Ok(( - self.res_mode_param_header10(), - self.res_blkdesc_mode10(), - self.res_mode_pages10(), - )) + let header = self.res_mode_param_header10()?; + let blkdescs = self.res_blkdesc_mode10()?; + let mode_pages = self.res_mode_pages10()?; + Ok((header, blkdescs, mode_pages)) } - pub fn cmd_inquiry(&mut self) -> &mut cmds::Inquiry { - plain::from_mut_bytes(&mut self.command_buffer).unwrap() + pub fn cmd_inquiry(&mut self) -> Result<&mut cmds::Inquiry> { + parse_mut_bytes("INQUIRY command", &mut self.command_buffer) + } + pub fn cmd_mode_sense6(&mut self) -> Result<&mut cmds::ModeSense6> { + parse_mut_bytes("MODE SENSE(6) command", &mut self.command_buffer) } - pub fn cmd_mode_sense6(&mut self) -> &mut cmds::ModeSense6 { - plain::from_mut_bytes(&mut self.command_buffer).unwrap() + pub fn cmd_mode_sense10(&mut self) -> Result<&mut cmds::ModeSense10> { + parse_mut_bytes("MODE SENSE(10) command", &mut 
self.command_buffer) } - pub fn cmd_mode_sense10(&mut self) -> &mut cmds::ModeSense10 { - plain::from_mut_bytes(&mut self.command_buffer).unwrap() + pub fn cmd_request_sense(&mut self) -> Result<&mut cmds::RequestSense> { + parse_mut_bytes("REQUEST SENSE command", &mut self.command_buffer) } - pub fn cmd_request_sense(&mut self) -> &mut cmds::RequestSense { - plain::from_mut_bytes(&mut self.command_buffer).unwrap() + pub fn cmd_read_capacity10(&mut self) -> Result<&mut cmds::ReadCapacity10> { + parse_mut_bytes("READ CAPACITY(10) command", &mut self.command_buffer) } - pub fn cmd_read_capacity10(&mut self) -> &mut cmds::ReadCapacity10 { - plain::from_mut_bytes(&mut self.command_buffer).unwrap() + pub fn cmd_read16(&mut self) -> Result<&mut cmds::Read16> { + parse_mut_bytes("READ(16) command", &mut self.command_buffer) } - pub fn cmd_read16(&mut self) -> &mut cmds::Read16 { - plain::from_mut_bytes(&mut self.command_buffer).unwrap() + pub fn cmd_read10(&mut self) -> Result<&mut cmds::Read10> { + parse_mut_bytes("READ(10) command", &mut self.command_buffer) } - pub fn cmd_write16(&mut self) -> &mut cmds::Write16 { - plain::from_mut_bytes(&mut self.command_buffer).unwrap() + pub fn cmd_write16(&mut self) -> Result<&mut cmds::Write16> { + parse_mut_bytes("WRITE(16) command", &mut self.command_buffer) } - pub fn res_standard_inquiry_data(&self) -> &StandardInquiryData { - plain::from_bytes(&self.inquiry_buffer).unwrap() + pub fn cmd_write10(&mut self) -> Result<&mut cmds::Write10> { + parse_mut_bytes("WRITE(10) command", &mut self.command_buffer) } - pub fn res_ff_sense_data(&self) -> &cmds::FixedFormatSenseData { - plain::from_bytes(&self.data_buffer).unwrap() + pub fn res_standard_inquiry_data(&self) -> Result<&StandardInquiryData> { + parse_bytes("standard inquiry data", &self.inquiry_buffer) } - pub fn res_mode_param_header6(&self) -> &cmds::ModeParamHeader6 { - plain::from_bytes(&self.data_buffer).unwrap() + pub fn res_ff_sense_data(&self) -> 
Result<&cmds::FixedFormatSenseData> { + parse_bytes("fixed format sense data", &self.data_buffer) } - pub fn res_mode_param_header10(&self) -> &cmds::ModeParamHeader10 { - plain::from_bytes(&self.data_buffer).unwrap() + pub fn res_mode_param_header6(&self) -> Result<&cmds::ModeParamHeader6> { + parse_bytes("MODE SENSE(6) parameter header", &self.data_buffer) } - pub fn res_blkdesc_mode6(&self) -> &[cmds::ShortLbaModeParamBlkDesc] { - let header = self.res_mode_param_header6(); + pub fn res_mode_param_header10(&self) -> Result<&cmds::ModeParamHeader10> { + parse_bytes("MODE SENSE(10) parameter header", &self.data_buffer) + } + pub fn res_blkdesc_mode6(&self) -> Result<&[cmds::ShortLbaModeParamBlkDesc]> { + let header = self.res_mode_param_header6()?; let descs_start = mem::size_of::(); - plain::slice_from_bytes( - &self.data_buffer[descs_start..descs_start + usize::from(header.block_desc_len)], + let desc_len = usize::from(header.block_desc_len); + let descs_end = descs_start + .checked_add(desc_len) + .ok_or(ScsiError::Overflow("block descriptor length overflowed"))?; + if descs_end > self.data_buffer.len() { + return Err(ScsiError::Overflow( + "block descriptor length exceeds data buffer", + )); + } + Ok( + plain::slice_from_bytes(&self.data_buffer[descs_start..descs_end]) + .map_err(|_| ScsiError::Overflow("block descriptor alignment mismatch"))?, ) - .unwrap() } - pub fn res_blkdesc_mode10(&self) -> BlkDescSlice<'_> { - let header = self.res_mode_param_header10(); + pub fn res_blkdesc_mode10(&self) -> Result> { + let header = self.res_mode_param_header10()?; let descs_start = mem::size_of::(); + let descs_end = descs_start + .checked_add(usize::from(header.block_desc_len())) + .ok_or(ScsiError::Overflow("block descriptor length overflowed"))?; + let desc_range = descs_start..descs_end; + if desc_range.end > self.data_buffer.len() { + return Err(ScsiError::Overflow( + "block descriptor length exceeds data buffer", + )); + } if header.longlba() { - 
BlkDescSlice::Long( - plain::slice_from_bytes( - &self.data_buffer - [descs_start..descs_start + usize::from(header.block_desc_len())], - ) - .unwrap(), - ) - } else if self.res_standard_inquiry_data().periph_dev_ty() - != cmds::PeriphDeviceType::DirectAccess as u8 - && self.res_standard_inquiry_data().version() == cmds::InquiryVersion::Spc3 as u8 - { - BlkDescSlice::General( - plain::slice_from_bytes( - &self.data_buffer - [descs_start..descs_start + usize::from(header.block_desc_len())], - ) - .unwrap(), - ) + Ok(BlkDescSlice::Long( + plain::slice_from_bytes(&self.data_buffer[desc_range]).map_err(|_| { + ScsiError::Overflow("long LBA block descriptor alignment mismatch") + })?, + )) } else { - BlkDescSlice::Short( - plain::slice_from_bytes( - &self.data_buffer - [descs_start..descs_start + usize::from(header.block_desc_len())], - ) - .unwrap(), - ) + let inquiry = self.res_standard_inquiry_data()?; + if inquiry.periph_dev_ty() != cmds::PeriphDeviceType::DirectAccess as u8 + && inquiry.version() == cmds::InquiryVersion::Spc3 as u8 + { + Ok(BlkDescSlice::General( + plain::slice_from_bytes(&self.data_buffer[desc_range]).map_err(|_| { + ScsiError::Overflow("general block descriptor alignment mismatch") + })?, + )) + } else { + Ok(BlkDescSlice::Short( + plain::slice_from_bytes(&self.data_buffer[desc_range]).map_err(|_| { + ScsiError::Overflow("short LBA block descriptor alignment mismatch") + })?, + )) + } } } - pub fn res_mode_pages10(&self) -> impl Iterator> { - let header = self.res_mode_param_header10(); + pub fn res_mode_pages10(&self) -> Result> + '_> { + let header = self.res_mode_param_header10()?; let descs_start = mem::size_of::(); - let buffer = &self.data_buffer[descs_start + header.block_desc_len() as usize..]; - cmds::mode_page_iter(buffer) + let pages_start = descs_start + .checked_add(header.block_desc_len() as usize) + .ok_or(ScsiError::Overflow("mode page offset overflowed"))?; + if pages_start > self.data_buffer.len() { + return 
Err(ScsiError::Overflow("mode page offset exceeds data buffer")); + } + let buffer = &self.data_buffer[pages_start..]; + Ok(cmds::mode_page_iter(buffer)) + } + pub fn res_read_capacity10(&self) -> Result<&cmds::ReadCapacity10ParamData> { + parse_bytes("READ CAPACITY(10) parameter data", &self.data_buffer) + } + pub fn read_capacity16( + &mut self, + protocol: &mut dyn Protocol, + ) -> Result<&cmds::ReadCapacity16ParamData> { + let cmd = self.cmd_read_capacity16()?; + *cmd = cmds::ReadCapacity16::new(0); + self.data_buffer + .resize(mem::size_of::(), 0); + protocol.send_command( + &self.command_buffer[..16], + DeviceReqData::In(&mut self.data_buffer[..32]), + )?; + self.res_read_capacity16() } - pub fn res_read_capacity10(&self) -> &cmds::ReadCapacity10ParamData { - plain::from_bytes(&self.data_buffer).unwrap() + pub fn cmd_read_capacity16(&mut self) -> Result<&mut cmds::ReadCapacity16> { + parse_mut_bytes("READ CAPACITY(16) command", &mut self.command_buffer) + } + pub fn res_read_capacity16(&self) -> Result<&cmds::ReadCapacity16ParamData> { + parse_bytes("READ CAPACITY(16) parameter data", &self.data_buffer) } pub fn get_disk_size(&self) -> u64 { self.block_count * u64::from(self.block_size) @@ -258,44 +400,90 @@ impl Scsi { lba: u64, buffer: &mut [u8], ) -> Result { - let blocks_to_read = buffer.len() as u64 / u64::from(self.block_size); - let bytes_to_read = blocks_to_read as usize * self.block_size as usize; - let transfer_len = u32::try_from(blocks_to_read).or(Err(ScsiError::Overflow( - "number of blocks to read couldn't fit inside a u32", - )))?; + let mut blocks_to_read = buffer.len() as u64 / u64::from(self.block_size); + + if self.quirks.contains(UsbStorageQuirkFlags::MAX_SECTORS_64) + && blocks_to_read > MAX_SECTORS_64_LIMIT { - let read = self.cmd_read16(); - *read = cmds::Read16::new(lba, transfer_len, 0); + blocks_to_read = MAX_SECTORS_64_LIMIT; + } + + let bytes_to_read = blocks_to_read as usize * self.block_size as usize; + + if 
self.quirks.contains(UsbStorageQuirkFlags::INITIAL_READ10) { + let transfer_len = u16::try_from(blocks_to_read).or(Err(ScsiError::Overflow( + "number of blocks to read couldn't fit inside a u16 for READ(10)", + )))?; + { + let read = self.cmd_read10()?; + *read = cmds::Read10::new(lba, transfer_len, 0); + } + self.data_buffer.resize(bytes_to_read, 0u8); + let status = protocol.send_command( + &self.command_buffer[..10], + DeviceReqData::In(&mut self.data_buffer[..bytes_to_read]), + )?; + buffer[..bytes_to_read].copy_from_slice(&self.data_buffer[..bytes_to_read]); + Ok(status.bytes_transferred(bytes_to_read as u32)) + } else { + let transfer_len = u32::try_from(blocks_to_read).or(Err(ScsiError::Overflow( + "number of blocks to read couldn't fit inside a u32", + )))?; + { + let read = self.cmd_read16()?; + *read = cmds::Read16::new(lba, transfer_len, 0); + } + self.data_buffer.resize(bytes_to_read, 0u8); + let status = protocol.send_command( + &self.command_buffer[..16], + DeviceReqData::In(&mut self.data_buffer[..bytes_to_read]), + )?; + buffer[..bytes_to_read].copy_from_slice(&self.data_buffer[..bytes_to_read]); + Ok(status.bytes_transferred(bytes_to_read as u32)) } - // TODO: Use the to-be-written TransferReadStream instead of relying on everything being - // able to fit within a single buffer. 
- self.data_buffer.resize(bytes_to_read, 0u8); - let status = protocol.send_command( - &self.command_buffer[..16], - DeviceReqData::In(&mut self.data_buffer[..bytes_to_read]), - )?; - buffer[..bytes_to_read].copy_from_slice(&self.data_buffer[..bytes_to_read]); - Ok(status.bytes_transferred(bytes_to_read as u32)) } pub fn write(&mut self, protocol: &mut dyn Protocol, lba: u64, buffer: &[u8]) -> Result { - let blocks_to_write = buffer.len() as u64 / u64::from(self.block_size); - let bytes_to_write = blocks_to_write as usize * self.block_size as usize; - let transfer_len = u32::try_from(blocks_to_write).or(Err(ScsiError::Overflow( - "number of blocks to write couldn't fit inside a u32", - )))?; + let mut blocks_to_write = buffer.len() as u64 / u64::from(self.block_size); + + if self.quirks.contains(UsbStorageQuirkFlags::MAX_SECTORS_64) + && blocks_to_write > MAX_SECTORS_64_LIMIT { - let read = self.cmd_write16(); - *read = cmds::Write16::new(lba, transfer_len, 0); + blocks_to_write = MAX_SECTORS_64_LIMIT; + } + + let bytes_to_write = blocks_to_write as usize * self.block_size as usize; + + if self.quirks.contains(UsbStorageQuirkFlags::INITIAL_READ10) { + let transfer_len = u16::try_from(blocks_to_write).or(Err(ScsiError::Overflow( + "number of blocks to write couldn't fit inside a u16 for WRITE(10)", + )))?; + { + let write = self.cmd_write10()?; + *write = cmds::Write10::new(lba, transfer_len, 0); + } + self.data_buffer.resize(bytes_to_write, 0u8); + self.data_buffer[..bytes_to_write].copy_from_slice(&buffer[..bytes_to_write]); + let status = protocol.send_command( + &self.command_buffer[..10], + DeviceReqData::Out(&buffer[..bytes_to_write]), + )?; + Ok(status.bytes_transferred(bytes_to_write as u32)) + } else { + let transfer_len = u32::try_from(blocks_to_write).or(Err(ScsiError::Overflow( + "number of blocks to write couldn't fit inside a u32", + )))?; + { + let write = self.cmd_write16()?; + *write = cmds::Write16::new(lba, transfer_len, 0); + } + 
self.data_buffer.resize(bytes_to_write, 0u8); + self.data_buffer[..bytes_to_write].copy_from_slice(&buffer[..bytes_to_write]); + let status = protocol.send_command( + &self.command_buffer[..16], + DeviceReqData::Out(&buffer[..bytes_to_write]), + )?; + Ok(status.bytes_transferred(bytes_to_write as u32)) } - // TODO: Use the to-be-written TransferReadStream instead of relying on everything being - // able to fit within a single buffer. - self.data_buffer.resize(bytes_to_write, 0u8); - self.data_buffer[..bytes_to_write].copy_from_slice(&buffer[..bytes_to_write]); - let status = protocol.send_command( - &self.command_buffer[..16], - DeviceReqData::Out(&buffer[..bytes_to_write]), - )?; - Ok(status.bytes_transferred(bytes_to_write as u32)) } } #[derive(Debug)] diff --git a/drivers/usb/usbhubd/src/main.rs b/drivers/usb/usbhubd/src/main.rs index 0e58542d..b13bb58a 100644 --- a/drivers/usb/usbhubd/src/main.rs +++ b/drivers/usb/usbhubd/src/main.rs @@ -84,7 +84,7 @@ fn main() -> Result<(), Box> { })?; // Read hub descriptor - let (ports, usb_3) = if desc.major_version() >= 3 { + let (ports, usb_3, hub_think_time) = if desc.major_version() >= 3 { // USB 3.0 hubs let mut hub_desc = usb::HubDescriptorV3::default(); handle @@ -101,7 +101,7 @@ fn main() -> Result<(), Box> { "Failed to read USB 3 hub descriptor for port {port_id}: {err}" )) })?; - (hub_desc.ports, true) + (hub_desc.ports, true, None) } else { // USB 2.0 and earlier hubs let mut hub_desc = usb::HubDescriptorV2::default(); @@ -119,7 +119,7 @@ fn main() -> Result<(), Box> { "Failed to read USB 2 hub descriptor for port {port_id}: {err}" )) })?; - (hub_desc.ports, false) + (hub_desc.ports, false, hub_desc.tt_think_time(desc.protocol)) }; @@ -128,6 +128,7 @@ fn main() -> Result<(), Box> { interface_desc: None, //TODO: stalls on USB 3 hub: Some(interface_num), alternate_setting: None, //TODO: stalls on USB 3 hub: Some(if_desc.alternate_setting), hub_ports: Some(ports), + hub_think_time, }) .map_err(|err| { 
other_error(format!( diff --git a/drivers/usb/xhcid/src/usb/hub.rs b/drivers/usb/xhcid/src/usb/hub.rs index 2d278320..fb02b17b 100644 --- a/drivers/usb/xhcid/src/usb/hub.rs +++ b/drivers/usb/xhcid/src/usb/hub.rs @@ -17,6 +17,23 @@ unsafe impl plain::Plain for HubDescriptorV2 {} impl HubDescriptorV2 { pub const DESCRIPTOR_KIND: u8 = 0x29; + + pub fn tt_think_time(self, device_protocol: u8) -> Option { + const HUB_CHAR_TTTT: u16 = 0x0060; + const HUB_TTTT_8_BITS: u16 = 0x0000; + const HUB_TTTT_16_BITS: u16 = 0x0020; + const HUB_TTTT_24_BITS: u16 = 0x0040; + const HUB_TTTT_32_BITS: u16 = 0x0060; + + match self.characteristics & HUB_CHAR_TTTT { + HUB_TTTT_8_BITS if device_protocol != 0 => Some(0), + HUB_TTTT_16_BITS => Some(1), + HUB_TTTT_24_BITS => Some(2), + HUB_TTTT_32_BITS => Some(3), + _ => None, + } + } } @@ -196,3 +213,23 @@ impl HubPortStatus { } } } + +#[cfg(test)] +mod tests { + use super::HubDescriptorV2; + + #[test] + fn usb2_hub_tt_think_time_decodes_linux_compatible_values() { + let mut hub = HubDescriptorV2::default(); + + hub.characteristics = 0x0000; + assert_eq!(hub.tt_think_time(0), None); + assert_eq!(hub.tt_think_time(1), Some(0)); + + hub.characteristics = 0x0020; + assert_eq!(hub.tt_think_time(0), Some(1)); + + hub.characteristics = 0x0040; + assert_eq!(hub.tt_think_time(0), Some(2)); + + hub.characteristics = 0x0060; + assert_eq!(hub.tt_think_time(0), Some(3)); + } +} diff --git a/drivers/usb/xhcid/src/xhci/scheme.rs b/drivers/usb/xhcid/src/xhci/scheme.rs index 627d33a7..7eb553ae 100644 --- a/drivers/usb/xhcid/src/xhci/scheme.rs +++ b/drivers/usb/xhcid/src/xhci/scheme.rs @@ -1196,11 +1196,8 @@ impl Xhci { // Set hub data current_slot_a &= !(1 << 26); current_slot_b &= !HUB_PORTS_MASK; - current_slot_c &= !TT_THINK_TIME_MASK; if let Some(hub_ports) = req.hub_ports { current_slot_a |= 1 << 26; current_slot_b |= (u32::from(hub_ports) << HUB_PORTS_SHIFT) & HUB_PORTS_MASK; - if let Some(hub_think_time) = req.hub_think_time { - current_slot_c |= 
(u32::from(hub_think_time) << TT_THINK_TIME_SHIFT) & TT_THINK_TIME_MASK; - } } current_slot_c = apply_hub_tt_info(current_slot_c, req); @@ -3250,6 +3247,21 @@ fn resolve_active_alternates( active } +fn apply_hub_tt_info(current_slot_c: u32, req: &ConfigureEndpointsReq) -> u32 { + const TT_THINK_TIME_MASK: u32 = 0x0003_0000; + const TT_THINK_TIME_SHIFT: u8 = 16; + + let mut slot_c = current_slot_c & !TT_THINK_TIME_MASK; + if req.hub_ports.is_some() { + if let Some(hub_think_time) = req.hub_think_time { + slot_c |= (u32::from(hub_think_time) << TT_THINK_TIME_SHIFT) & TT_THINK_TIME_MASK; + } + } + slot_c +} + use lazy_static::lazy_static; use std::ops::{Add, Div, Rem}; @@ -3283,4 +3295,18 @@ mod tests { assert_eq!(resolved.get(&0), Some(&1)); assert_eq!(resolved.get(&1), Some(&2)); } + + #[test] + fn apply_hub_tt_info_only_sets_bits_for_hub_requests() { + let req = ConfigureEndpointsReq { + config_desc: 1, + interface_desc: None, + alternate_setting: None, + hub_ports: Some(4), + hub_think_time: Some(3), + }; + assert_eq!(apply_hub_tt_info(0, &req), 0x0003_0000); + + let no_hub = ConfigureEndpointsReq { hub_ports: None, ..req.clone() }; + assert_eq!(apply_hub_tt_info(0x0003_0000, &no_hub), 0); + } } diff --git a/drivers/usb/usbhubd/src/main.rs b/drivers/usb/usbhubd/src/main.rs index 0e58542d..b13bb58a 100644 --- a/drivers/usb/usbhubd/src/main.rs +++ b/drivers/usb/usbhubd/src/main.rs @@ -84,7 +84,7 @@ fn main() -> Result<(), Box> { })?; // Read hub descriptor - let (ports, usb_3) = if desc.major_version() >= 3 { + let (ports, usb_3, hub_think_time) = if desc.major_version() >= 3 { // USB 3.0 hubs let mut hub_desc = usb::HubDescriptorV3::default(); handle @@ -101,7 +101,7 @@ fn main() -> Result<(), Box> { "Failed to read USB 3 hub descriptor for port {port_id}: {err}" )) })?; - (hub_desc.ports, true) + (hub_desc.ports, true, None) } else { // USB 2.0 and earlier hubs let mut hub_desc = usb::HubDescriptorV2::default(); @@ -119,7 +119,7 @@ fn main() -> Result<(), Box> { 
"Failed to read USB 2 hub descriptor for port {port_id}: {err}" )) })?; - (hub_desc.ports, false) + (hub_desc.ports, false, hub_desc.tt_think_time(desc.protocol)) }; @@ -128,6 +128,7 @@ fn main() -> Result<(), Box> { interface_desc: None, //TODO: stalls on USB 3 hub: Some(interface_num), alternate_setting: None, //TODO: stalls on USB 3 hub: Some(if_desc.alternate_setting), hub_ports: Some(ports), + hub_think_time, }) .map_err(|err| { other_error(format!( diff --git a/drivers/usb/xhcid/src/usb/hub.rs b/drivers/usb/xhcid/src/usb/hub.rs index b7dc4d54..2d278320 100644 --- a/drivers/usb/xhcid/src/usb/hub.rs +++ b/drivers/usb/xhcid/src/usb/hub.rs @@ -17,6 +17,23 @@ unsafe impl plain::Plain for HubDescriptorV2 {} impl HubDescriptorV2 { pub const DESCRIPTOR_KIND: u8 = 0x29; + + pub fn tt_think_time(self, device_protocol: u8) -> Option { + const HUB_CHAR_TTTT: u16 = 0x0060; + const HUB_TTTT_8_BITS: u16 = 0x0000; + const HUB_TTTT_16_BITS: u16 = 0x0020; + const HUB_TTTT_24_BITS: u16 = 0x0040; + const HUB_TTTT_32_BITS: u16 = 0x0060; + + match self.characteristics & HUB_CHAR_TTTT { + HUB_TTTT_8_BITS if device_protocol != 0 => Some(0), + HUB_TTTT_16_BITS => Some(1), + HUB_TTTT_24_BITS => Some(2), + HUB_TTTT_32_BITS => Some(3), + _ => None, + } + } } @@ -196,3 +213,23 @@ impl HubPortStatus { } } } + +#[cfg(test)] +mod tests { + use super::HubDescriptorV2; + + #[test] + fn usb2_hub_tt_think_time_decodes_linux_compatible_values() { + let mut hub = HubDescriptorV2::default(); + + hub.characteristics = 0x0000; + assert_eq!(hub.tt_think_time(0), None); + assert_eq!(hub.tt_think_time(1), Some(0)); + + hub.characteristics = 0x0020; + assert_eq!(hub.tt_think_time(0), Some(1)); + + hub.characteristics = 0x0040; + assert_eq!(hub.tt_think_time(0), Some(2)); + + hub.characteristics = 0x0060; + assert_eq!(hub.tt_think_time(0), Some(3)); + } +} diff --git a/drivers/usb/xhcid/src/xhci/scheme.rs b/drivers/usb/xhcid/src/xhci/scheme.rs index d5266ca0..627d33a7 100644 --- 
a/drivers/usb/xhcid/src/xhci/scheme.rs +++ b/drivers/usb/xhcid/src/xhci/scheme.rs @@ -1196,11 +1196,8 @@ impl Xhci { // Set hub data current_slot_a &= !(1 << 26); current_slot_b &= !HUB_PORTS_MASK; - current_slot_c &= !TT_THINK_TIME_MASK; if let Some(hub_ports) = req.hub_ports { current_slot_a |= 1 << 26; current_slot_b |= (u32::from(hub_ports) << HUB_PORTS_SHIFT) & HUB_PORTS_MASK; - if let Some(hub_think_time) = req.hub_think_time { - current_slot_c |= (u32::from(hub_think_time) << TT_THINK_TIME_SHIFT) & TT_THINK_TIME_MASK; - } } + current_slot_c = apply_hub_tt_info(current_slot_c, req); input_context.device.slot.a.write(current_slot_a); input_context.device.slot.b.write(current_slot_b); @@ -3250,6 +3247,21 @@ fn resolve_active_alternates( active } +fn apply_hub_tt_info(current_slot_c: u32, req: &ConfigureEndpointsReq) -> u32 { + const TT_THINK_TIME_MASK: u32 = 0x0003_0000; + const TT_THINK_TIME_SHIFT: u8 = 16; + + let mut slot_c = current_slot_c & !TT_THINK_TIME_MASK; + if req.hub_ports.is_some() { + if let Some(hub_think_time) = req.hub_think_time { + slot_c |= (u32::from(hub_think_time) << TT_THINK_TIME_SHIFT) & TT_THINK_TIME_MASK; + } + } + slot_c +} + use lazy_static::lazy_static; use std::ops::{Add, Div, Rem}; @@ -3283,4 +3295,18 @@ mod tests { assert_eq!(resolved.get(&0), Some(&1)); assert_eq!(resolved.get(&1), Some(&2)); } + + #[test] + fn apply_hub_tt_info_only_sets_bits_for_hub_requests() { + let req = ConfigureEndpointsReq { + config_desc: 1, + interface_desc: None, + alternate_setting: None, + hub_ports: Some(4), + hub_think_time: Some(3), + }; + assert_eq!(apply_hub_tt_info(0, &req), 0x0003_0000); + + let no_hub = ConfigureEndpointsReq { hub_ports: None, ..req.clone() }; + assert_eq!(apply_hub_tt_info(0x0003_0000, &no_hub), 0); + } } diff --git a/drivers/usb/usbhubd/src/main.rs b/drivers/usb/usbhubd/src/main.rs index 2c8b9876..68538b77 100644 --- a/drivers/usb/usbhubd/src/main.rs +++ b/drivers/usb/usbhubd/src/main.rs @@ -1,27 +1,41 @@ -use std::{env, 
thread, time}; +use std::{env, error::Error, io, thread, time}; use xhcid_interface::{ - plain, usb, ConfigureEndpointsReq, DevDesc, DeviceReqData, PortId, PortReqRecipient, PortReqTy, - XhciClientHandle, + plain, usb, ConfigureEndpointsReq, DevDesc, DeviceReqData, EndpDirection, PortId, + PortReqRecipient, PortReqTy, XhciClientHandle, XhciEndpHandle, }; -fn main() { +fn invalid_input_error(message: impl Into) -> Box { + let message = message.into(); + log::error!("{message}"); + Box::new(io::Error::new(io::ErrorKind::InvalidInput, message)) +} + +fn other_error(message: impl Into) -> Box { + let message = message.into(); + log::error!("{message}"); + Box::new(io::Error::other(message)) +} + +fn main() -> Result<(), Box> { common::init(); let mut args = env::args().skip(1); const USAGE: &'static str = "usbhubd "; - let scheme = args.next().expect(USAGE); + let scheme = args.next().ok_or_else(|| invalid_input_error(USAGE))?; let port_id = args .next() - .expect(USAGE) + .ok_or_else(|| invalid_input_error(USAGE))? .parse::() - .expect("Expected port ID"); + .map_err(|err| invalid_input_error(format!("Expected port ID: {err}")))?; let interface_num = args .next() - .expect(USAGE) + .ok_or_else(|| invalid_input_error(USAGE))? 
.parse::() - .expect("Expected integer as input of interface"); + .map_err(|err| { + invalid_input_error(format!("Expected integer as input of interface: {err}")) + })?; log::info!( "USB HUB driver spawned with scheme `{}`, port {}, interface {}", @@ -39,11 +53,16 @@ fn main() { common::file_level(), ); - let handle = - XhciClientHandle::new(scheme.clone(), port_id).expect("Failed to open XhciClientHandle"); - let desc: DevDesc = handle - .get_standard_descs() - .expect("Failed to get standard descriptors"); + let handle = XhciClientHandle::new(scheme.clone(), port_id).map_err(|err| { + other_error(format!( + "Failed to open XhciClientHandle for scheme `{scheme}` port {port_id}: {err}" + )) + })?; + let desc: DevDesc = handle.get_standard_descs().map_err(|err| { + other_error(format!( + "Failed to get standard descriptors for hub on port {port_id}: {err}" + )) + })?; let (conf_desc, if_desc) = desc .config_descs @@ -58,7 +77,11 @@ fn main() { })?; Some((conf_desc.clone(), if_desc)) }) - .expect("Failed to find suitable configuration"); + .ok_or_else(|| { + other_error(format!( + "Failed to find suitable configuration for hub interface {interface_num}" + )) + })?; // Read hub descriptor let (ports, usb_3) = if desc.major_version() >= 3 { @@ -73,7 +96,11 @@ fn main() { 0, DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut hub_desc) }), ) - .expect("Failed to read hub descriptor"); + .map_err(|err| { + other_error(format!( + "Failed to read USB 3 hub descriptor for port {port_id}: {err}" + )) + })?; (hub_desc.ports, true) } else { // USB 2.0 and earlier hubs @@ -87,7 +114,11 @@ fn main() { 0, DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut hub_desc) }), ) - .expect("Failed to read hub descriptor"); + .map_err(|err| { + other_error(format!( + "Failed to read USB 2 hub descriptor for port {port_id}: {err}" + )) + })?; (hub_desc.ports, false) }; @@ -99,7 +130,11 @@ fn main() { alternate_setting: None, //TODO: stalls on USB 3 hub: Some(if_desc.alternate_setting), 
hub_ports: Some(ports), }) - .expect("Failed to configure endpoints after reading hub descriptor"); + .map_err(|err| { + other_error(format!( + "Failed to configure endpoints after reading hub descriptor on port {port_id}: {err}" + )) + })?; if usb_3 { handle @@ -111,139 +146,353 @@ fn main() { 0, DeviceReqData::NoData, ) - .expect("Failed to set hub depth"); + .map_err(|err| { + other_error(format!("Failed to set hub depth for port {port_id}: {err}")) + })?; } + let interrupt_endpoint_desc = if_desc + .endpoints + .iter() + .find(|endp_desc| endp_desc.is_interrupt() && endp_desc.direction() == EndpDirection::In) + .copied(); + let status_change_bitmap_size = (usize::from(ports) + 8) / 8; + // Initialize states struct PortState { port_id: PortId, port_sts: usb::HubPortStatus, handle: XhciClientHandle, attached: bool, + suspended: bool, } impl PortState { - pub fn ensure_attached(&mut self, attached: bool) { + pub fn ensure_attached(&mut self, attached: bool) -> io::Result<()> { if attached == self.attached { - return; + return Ok(()); } if attached { - self.handle.attach().expect("Failed to attach"); + self.handle.attach().map_err(|err| { + io::Error::other(format!( + "Failed to attach child device on port {}: {err}", + self.port_id + )) + })?; } else { - self.handle.detach().expect("Failed to detach"); + self.handle.detach().map_err(|err| { + io::Error::other(format!( + "Failed to detach child device on port {}: {err}", + self.port_id + )) + })?; } self.attached = attached; + Ok(()) + } + + pub fn ensure_suspended(&mut self, suspended: bool) -> io::Result<()> { + if suspended == self.suspended { + return Ok(()); + } + + if suspended { + self.handle.suspend_device().map_err(|err| { + io::Error::other(format!( + "Failed to suspend child device on port {}: {err}", + self.port_id + )) + })?; + } else { + self.handle.resume_device().map_err(|err| { + io::Error::other(format!( + "Failed to resume child device on port {}: {err}", + self.port_id + )) + })?; + } + + 
self.suspended = suspended; + Ok(()) } } let mut states = Vec::new(); for port in 1..=ports { - let child_port_id = port_id.child(port).expect("Cannot get child port ID"); - states.push(PortState { + let child_port_id = match port_id.child(port) { + Ok(child_port_id) => child_port_id, + Err(err) => { + log::warn!( + "Skipping hub port {port}: cannot derive child port ID from parent port {port_id}: {err}" + ); + states.push(None); + continue; + } + }; + + let child_handle = match XhciClientHandle::new(scheme.clone(), child_port_id) { + Ok(child_handle) => child_handle, + Err(err) => { + log::warn!( + "Skipping hub port {port} (child port {child_port_id}): failed to open XhciClientHandle: {err}" + ); + states.push(None); + continue; + } + }; + + states.push(Some(PortState { port_id: child_port_id, port_sts: if usb_3 { usb::HubPortStatus::V3(usb::HubPortStatusV3::default()) } else { usb::HubPortStatus::V2(usb::HubPortStatusV2::default()) }, - handle: XhciClientHandle::new(scheme.clone(), child_port_id) - .expect("Failed to open XhciClientHandle"), + handle: child_handle, attached: false, - }); + suspended: false, + })); } - //TODO: use change flags? 
- loop { - for port in 1..=ports { - let port_idx: usize = port.checked_sub(1).unwrap().into(); - let state = states.get_mut(port_idx).unwrap(); - - let port_sts = if usb_3 { - let mut port_sts = usb::HubPortStatusV3::default(); - handle - .device_request( - PortReqTy::Class, - PortReqRecipient::Other, - usb::SetupReq::GetStatus as u8, - 0, - port as u16, - DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut port_sts) }), - ) - .expect("Failed to retrieve port status"); - usb::HubPortStatus::V3(port_sts) - } else { - let mut port_sts = usb::HubPortStatusV2::default(); - handle - .device_request( - PortReqTy::Class, - PortReqRecipient::Other, - usb::SetupReq::GetStatus as u8, - 0, - port as u16, - DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut port_sts) }), - ) - .expect("Failed to retrieve port status"); - usb::HubPortStatus::V2(port_sts) - }; - if state.port_sts != port_sts { - state.port_sts = port_sts; - log::info!("port {} status {:X?}", port, port_sts); - } - - // Ensure port is powered on - if !port_sts.is_powered() { - log::info!("power on port {port}"); - handle - .device_request( - PortReqTy::Class, - PortReqRecipient::Other, - usb::SetupReq::SetFeature as u8, - usb::HubPortFeature::PortPower as u16, - port as u16, - DeviceReqData::NoData, - ) - .expect("Failed to set port power"); - state.ensure_attached(false); - continue; + let mut process_port = |port: u8| -> Result<(), Box> { + let port_idx: usize = match port.checked_sub(1) { + Some(port_idx) => port_idx.into(), + None => { + return Err(other_error(format!( + "Failed to derive zero-based index for hub port {port}" + ))); } + }; + let Some(state_entry) = states.get_mut(port_idx) else { + return Err(other_error(format!( + "Missing state entry for hub port {port} at index {port_idx}" + ))); + }; + let Some(state) = state_entry.as_mut() else { + return Ok(()); + }; - // Ignore disconnected port - if !port_sts.is_connected() { - state.ensure_attached(false); - continue; + let port_sts = if usb_3 { + 
let mut port_sts = usb::HubPortStatusV3::default(); + if let Err(err) = handle.device_request( + PortReqTy::Class, + PortReqRecipient::Other, + usb::SetupReq::GetStatus as u8, + 0, + port as u16, + DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut port_sts) }), + ) { + log::warn!("Failed to retrieve USB 3 status for hub port {port}: {err}"); + if let Err(err) = state.ensure_attached(false) { + log::warn!( + "Failed to detach child device after status error on hub port {port}: {err}" + ); + } + return Ok(()); + } + usb::HubPortStatus::V3(port_sts) + } else { + let mut port_sts = usb::HubPortStatusV2::default(); + if let Err(err) = handle.device_request( + PortReqTy::Class, + PortReqRecipient::Other, + usb::SetupReq::GetStatus as u8, + 0, + port as u16, + DeviceReqData::In(unsafe { plain::as_mut_bytes(&mut port_sts) }), + ) { + log::warn!("Failed to retrieve USB 2 status for hub port {port}: {err}"); + if let Err(err) = state.ensure_attached(false) { + log::warn!( + "Failed to detach child device after status error on hub port {port}: {err}" + ); + } + return Ok(()); } + usb::HubPortStatus::V2(port_sts) + }; + if state.port_sts != port_sts { + state.port_sts = port_sts; + log::info!("port {} status {:X?}", port, port_sts); + } - // Ignore port in reset - if port_sts.is_resetting() { - state.ensure_attached(false); - continue; + // Ensure port is powered on + if !port_sts.is_powered() { + if let Err(err) = state.ensure_suspended(false) { + log::warn!("Failed to resume child device for unpowered hub port {port}: {err}"); } + log::info!("power on port {port}"); + if let Err(err) = handle.device_request( + PortReqTy::Class, + PortReqRecipient::Other, + usb::SetupReq::SetFeature as u8, + usb::HubPortFeature::PortPower as u16, + port as u16, + DeviceReqData::NoData, + ) { + log::warn!("Failed to set port power for hub port {port}: {err}"); + if let Err(err) = state.ensure_attached(false) { + log::warn!( + "Failed to detach child device after power error on hub port 
{port}: {err}" + ); + } + return Ok(()); + } + if let Err(err) = state.ensure_attached(false) { + log::warn!("Failed to detach child device for unpowered hub port {port}: {err}"); + } + return Ok(()); + } - // Ensure port is enabled - if !port_sts.is_enabled() { - log::info!("reset port {port}"); - handle - .device_request( - PortReqTy::Class, - PortReqRecipient::Other, - usb::SetupReq::SetFeature as u8, - usb::HubPortFeature::PortReset as u16, - port as u16, - DeviceReqData::NoData, - ) - .expect("Failed to set port enable"); - state.ensure_attached(false); - continue; + // Ignore disconnected port + if !port_sts.is_connected() { + if let Err(err) = state.ensure_suspended(false) { + log::warn!("Failed to resume child device for disconnected hub port {port}: {err}"); + } + if let Err(err) = state.ensure_attached(false) { + log::warn!("Failed to detach child device for disconnected hub port {port}: {err}"); } + return Ok(()); + } - state.ensure_attached(true); + // Ignore port in reset + if port_sts.is_resetting() { + if let Err(err) = state.ensure_suspended(false) { + log::warn!("Failed to resume child device for resetting hub port {port}: {err}"); + } + if let Err(err) = state.ensure_attached(false) { + log::warn!("Failed to detach child device for resetting hub port {port}: {err}"); + } + return Ok(()); } - //TODO: use interrupts or poll faster? 
- thread::sleep(time::Duration::new(1, 0)); + // Ensure port is enabled + if !port_sts.is_enabled() { + log::info!("reset port {port}"); + if let Err(err) = handle.device_request( + PortReqTy::Class, + PortReqRecipient::Other, + usb::SetupReq::SetFeature as u8, + usb::HubPortFeature::PortReset as u16, + port as u16, + DeviceReqData::NoData, + ) { + log::warn!("Failed to reset hub port {port}: {err}"); + if let Err(err) = state.ensure_attached(false) { + log::warn!( + "Failed to detach child device after reset error on hub port {port}: {err}" + ); + } + return Ok(()); + } + if let Err(err) = state.ensure_attached(false) { + log::warn!("Failed to detach child device while resetting hub port {port}: {err}"); + } + return Ok(()); + } + + if let Err(err) = state.ensure_suspended(port_sts.is_suspended()) { + log::warn!("Failed to synchronize child suspend state for hub port {port}: {err}"); + } + + if let Err(err) = state.ensure_attached(true) { + log::warn!("Failed to attach child device for hub port {port}: {err}"); + } + + Ok(()) + }; + + let try_open_interrupt_endpoint = || -> Option { + let Some(interrupt_endpoint_desc) = interrupt_endpoint_desc else { + return None; + }; + + let interrupt_endpoint_num = interrupt_endpoint_desc.address & 0x0F; + match handle.open_endpoint(interrupt_endpoint_num) { + Ok(interrupt_endpoint) => { + log::info!( + "Using hub interrupt endpoint {} IN (max packet size {})", + interrupt_endpoint_num, + interrupt_endpoint_desc.max_packet_size + ); + Some(interrupt_endpoint) + } + Err(err) => { + log::warn!( + "Failed to open hub interrupt endpoint {} on port {}: {}; falling back to polling", + interrupt_endpoint_num, + port_id, + err + ); + None + } + } + }; + + for port in 1..=ports { + process_port(port)?; + } + + if interrupt_endpoint_desc.is_none() { + log::warn!( + "No interrupt IN endpoint found for hub on port {}; falling back to polling", + port_id + ); } - //TODO: read interrupt port for changes + let mut interrupt_endpoint = 
try_open_interrupt_endpoint(); + let mut poll_iterations: u32 = 0; + + loop { + if let Some(endp) = interrupt_endpoint.as_mut() { + let mut change_bitmap = vec![0_u8; status_change_bitmap_size]; + match endp.transfer_read(&mut change_bitmap) { + Ok(_) => { + for port in 1..=ports { + let bit = usize::from(port); + let byte_idx = bit / 8; + let bit_idx = bit % 8; + + if change_bitmap + .get(byte_idx) + .is_some_and(|byte| ((byte >> bit_idx) & 1) != 0) + { + process_port(port)?; + } + } + poll_iterations = 0; + continue; + } + Err(err) => { + log::warn!( + "Failed to read hub interrupt endpoint on port {}: {}; falling back to polling", + port_id, + err + ); + interrupt_endpoint = None; + poll_iterations = 0; + } + } + } + + for port in 1..=ports { + process_port(port)?; + } + + poll_iterations = poll_iterations.saturating_add(1); + if interrupt_endpoint.is_none() + && interrupt_endpoint_desc.is_some() + && poll_iterations % 10 == 0 + { + interrupt_endpoint = try_open_interrupt_endpoint(); + if interrupt_endpoint.is_some() { + poll_iterations = 0; + continue; + } + } + + thread::sleep(time::Duration::new(1, 0)); + } } diff --git a/drivers/usb/xhcid/src/driver_interface.rs b/drivers/usb/xhcid/src/driver_interface.rs index 727f8d7e..557e6bce 100644 --- a/drivers/usb/xhcid/src/driver_interface.rs +++ b/drivers/usb/xhcid/src/driver_interface.rs @@ -560,6 +560,16 @@ impl XhciClientHandle { let _bytes_written = file.write(&[])?; Ok(()) } + pub fn suspend_device(&self) -> result::Result<(), XhciClientHandleError> { + let file = self.fd.openat("suspend", libredox::flag::O_WRONLY, 0)?; + let _bytes_written = file.write(&[])?; + Ok(()) + } + pub fn resume_device(&self) -> result::Result<(), XhciClientHandleError> { + let file = self.fd.openat("resume", libredox::flag::O_WRONLY, 0)?; + let _bytes_written = file.write(&[])?; + Ok(()) + } pub fn get_standard_descs(&self) -> result::Result { let json = self.read("descriptors")?; Ok(serde_json::from_slice(&json)?) 
diff --git a/drivers/usb/xhcid/Cargo.toml b/drivers/usb/xhcid/Cargo.toml index 778376b0..1651bcf5 100644 --- a/drivers/usb/xhcid/Cargo.toml +++ b/drivers/usb/xhcid/Cargo.toml @@ -32,6 +32,7 @@ common = { path = "../../common" } daemon = { path = "../../../daemon" } pcid = { path = "../../pcid" } +redox-driver-sys = { path = "../../../../../../../local/recipes/drivers/redox-driver-sys/source" } libredox.workspace = true regex = "1.10.6" diff --git a/drivers/usb/xhcid/src/usb_quirks.rs b/drivers/usb/xhcid/src/usb_quirks.rs new file mode 100644 index 00000000..83ca324d --- /dev/null +++ b/drivers/usb/xhcid/src/usb_quirks.rs @@ -0,0 +1,10 @@ +pub use redox_driver_sys::quirks::UsbQuirkFlags; + +use crate::driver_interface::PortId; + +pub fn lookup_usb_quirks(vendor: u16, product: u16) -> UsbQuirkFlags { + redox_driver_sys::quirks::lookup_usb_quirks(vendor, product) +} + +pub fn lookup_usb_quirks_early(_port_id: PortId) -> UsbQuirkFlags { + UsbQuirkFlags::empty() +} diff --git a/drivers/usb/xhcid/src/main.rs b/drivers/usb/xhcid/src/main.rs index 25b2fdd6..d5dea9b2 100644 --- a/drivers/usb/xhcid/src/main.rs +++ b/drivers/usb/xhcid/src/main.rs @@ -49,6 +49,7 @@ use crate::xhci::{InterruptMethod, Xhci}; // mean anything. pub mod driver_interface; +mod usb_quirks; mod usb; mod xhci; @@ -141,8 +142,19 @@ fn daemon_with_context_size( let address = unsafe { pcid_handle.map_bar(0) }.ptr.as_ptr() as usize; - let (irq_file, interrupt_method) = (None, InterruptMethod::Polling); //get_int_method(&mut pcid_handle); - //TODO: Fix interrupts. 
+ let (irq_file, interrupt_method) = get_int_method(&mut pcid_handle); + + match interrupt_method { + InterruptMethod::Msi => { + log::info!("xhcid: using MSI/MSI-X interrupt delivery"); + } + InterruptMethod::Intx => { + log::info!("xhcid: using legacy INTx interrupt delivery"); + } + InterruptMethod::Polling => { + log::warn!("xhcid: using polling event delivery"); + } + } log::info!("XHCI {}", pci_config.func.display()); diff --git a/drivers/usb/xhcid/src/usb/hub.rs b/drivers/usb/xhcid/src/usb/hub.rs index 9dab55e8..fe6efdd2 100644 --- a/drivers/usb/xhcid/src/usb/hub.rs +++ b/drivers/usb/xhcid/src/usb/hub.rs @@ -86,8 +86,10 @@ pub enum HubPortFeature { PortOverCurrent = 3, PortReset = 4, PortLinkState = 5, + PortSuspend = 2, /* USB 2.0 Table 11-17 PORT_SUSPEND; 7 is not a defined hub feature selector */ PortPower = 8, CPortConnection = 16, + CPortSuspend = 18, /* USB 2.0 Table 11-17 C_PORT_SUSPEND */ CPortOverCurrent = 19, CPortReset = 20, } @@ -184,4 +186,11 @@ impl HubPortStatus { Self::V3(x) => x.contains(HubPortStatusV3::ENABLE), } } + + pub fn is_suspended(&self) -> bool { + match self { + Self::V2(x) => x.contains(HubPortStatusV2::SUSPEND), + Self::V3(_) => false, // NOTE(review): USB 3 hubs report suspend as PORT_LINK_STATE U3, not a SUSPEND bit, so this never reports a suspended USB 3 port + } + } } diff --git a/drivers/usb/xhcid/src/xhci/device_enumerator.rs b/drivers/usb/xhcid/src/xhci/device_enumerator.rs index 74b9f732..1f144ac9 100644 --- a/drivers/usb/xhcid/src/xhci/device_enumerator.rs +++ b/drivers/usb/xhcid/src/xhci/device_enumerator.rs @@ -54,6 +54,7 @@ impl DeviceEnumerator { }; if flags.contains(PortFlags::CCS) { + let early_quirks = crate::usb_quirks::lookup_usb_quirks_early(port_id); debug!( "Received Device Connect Port Status Change Event with port flags {:?}", flags @@ -85,7 +86,17 @@ impl DeviceEnumerator { port.clear_prc(); - std::thread::sleep(Duration::from_millis(16)); //Some controllers need some extra time to make the transition.
+ let delay_ms = if early_quirks + .contains(crate::usb_quirks::UsbQuirkFlags::HUB_SLOW_RESET) + { + 200 + } else if early_quirks.contains(crate::usb_quirks::UsbQuirkFlags::RESET_DELAY) { + 100 + } else { + 16 + }; + + std::thread::sleep(Duration::from_millis(delay_ms)); // Some devices need extra time to settle after reset. let flags = port.flags(); diff --git a/drivers/usb/xhcid/src/xhci/mod.rs b/drivers/usb/xhcid/src/xhci/mod.rs index f2143676..d81648bf 100644 --- a/drivers/usb/xhcid/src/xhci/mod.rs +++ b/drivers/usb/xhcid/src/xhci/mod.rs @@ -11,12 +11,13 @@ //! documents are specified in the crate-level documentation. use std::collections::BTreeMap; use std::convert::TryFrom; -use std::fs::File; +use std::fs::{self, File}; +use std::time::Duration; use std::sync::atomic::AtomicUsize; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Condvar, Mutex}; use std::{mem, process, slice, thread}; -use syscall::error::{Error, Result, EBADF, EBADMSG, EIO, ENOENT}; +use syscall::error::{Error, Result, EBADF, EBADMSG, EBUSY, EIO, ENOENT}; use syscall::{EAGAIN, PAGE_SIZE}; use chashmap::CHashMap; @@ -77,7 +78,55 @@ pub enum InterruptMethod { Msi, } +const XHCID_TEST_HOOK_PATH: &str = "/tmp/xhcid-test-hook"; // NOTE(review): consulted on every attach/detach and not cfg-gated; any process able to write /tmp can inject delays +const XHCID_TEST_HOOK_MAX_DELAY_MS: u64 = 5_000; // upper bound applied to any delay_*_ms= hook value + impl Xhci { + fn read_test_hook_command_from_path(path: &str) -> Option { + let contents = fs::read_to_string(path).ok()?; + contents + .lines() + .map(str::trim) + .find(|line| !line.is_empty() && !line.starts_with('#')) + .map(ToOwned::to_owned) + } + + fn clear_test_hook_command_path(path: &str) { + if let Err(err) = fs::remove_file(path) { + if err.kind() != std::io::ErrorKind::NotFound { + warn!( + "failed to remove xhcid test hook file {}: {}", + path, err + ); + } + } + } + + fn consume_test_hook_from_path(path: &str, expected: &str) -> bool { + match Self::read_test_hook_command_from_path(path) { + Some(command) if command == expected => { + Self::clear_test_hook_command_path(path); + true + } + _ => false, +
} + } + + fn consume_test_hook_delay_ms_from_path(path: &str, prefix: &str) -> Option { + let command = Self::read_test_hook_command_from_path(path)?; + let delay_ms = command.strip_prefix(prefix)?.parse::().ok()?; + Self::clear_test_hook_command_path(path); + Some(delay_ms.min(XHCID_TEST_HOOK_MAX_DELAY_MS)) + } + + pub(crate) fn consume_test_hook(&self, expected: &str) -> bool { + Self::consume_test_hook_from_path(XHCID_TEST_HOOK_PATH, expected) + } + + pub(crate) fn consume_test_hook_delay_ms(&self, prefix: &str) -> Option { + Self::consume_test_hook_delay_ms_from_path(XHCID_TEST_HOOK_PATH, prefix) + } + /// Gets descriptors, before the port state is initiated. async fn get_desc_raw( &self, @@ -104,7 +153,18 @@ impl Xhci { ); let future = { - let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(ENOENT))?; + let mut published_port_state = self.port_states.get_mut(&port); + let mut staged_port_state = if published_port_state.is_none() { // during attach the state exists only in staged_port_states + self.staged_port_states.get_mut(&port) + } else { + None + }; + + let port_state = published_port_state + .as_deref_mut() + .or_else(|| staged_port_state.as_deref_mut()) + .ok_or(Error::new(ENOENT))?; + let ring = port_state .endpoint_states .get_mut(&0) @@ -150,7 +210,7 @@ impl Xhci { trace!("Handling the transfer event TRB!"); self::scheme::handle_transfer_event_trb("GET_DESC", &event_trb, &status_trb)?; - //self.event_handler_finished(); + self.event_handler_finished(); Ok(()) } @@ -283,6 +343,7 @@ pub struct Xhci { handles: CHashMap, next_handle: AtomicUsize, port_states: CHashMap>, + staged_port_states: CHashMap>, drivers: CHashMap>, scheme_name: String, @@ -311,6 +372,144 @@ struct PortState { input_context: Mutex>>, dev_desc: Option, endpoint_states: BTreeMap, + quirks: crate::usb_quirks::UsbQuirkFlags, + pm_state: PortPmState, + lifecycle: Arc, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum PortLifecycleState { + Attaching, + Attached, + Detaching, +} + +struct PortLifecycleInner { +
state: PortLifecycleState, + active_operations: usize, +} + +pub(crate) struct PortLifecycle { + inner: Mutex, + idle: Condvar, +} + +impl PortLifecycle { + pub(crate) fn new_attaching() -> Self { + Self { + inner: Mutex::new(PortLifecycleInner { + state: PortLifecycleState::Attaching, + active_operations: 1, // the attach in progress counts as one operation until finish_attach_success/failure releases it + }), + idle: Condvar::new(), + } + } + + fn lock_inner(&self) -> std::sync::MutexGuard<'_, PortLifecycleInner> { + self.inner.lock().unwrap_or_else(|err| err.into_inner()) + } + + pub(crate) fn state(&self) -> PortLifecycleState { + self.lock_inner().state + } + + pub(crate) fn begin_operation(&self, allow_attaching: bool) -> Result<()> { + let mut inner = self.lock_inner(); + + let allowed = match inner.state { + PortLifecycleState::Attached => true, + PortLifecycleState::Attaching => allow_attaching, + PortLifecycleState::Detaching => false, + }; + + if !allowed { + return Err(Error::new(EBUSY)); + } + + inner.active_operations += 1; + Ok(()) + } + + pub(crate) fn finish_operation(&self) { + let mut inner = self.lock_inner(); + + if inner.active_operations == 0 { + return; + } + + inner.active_operations -= 1; + if inner.active_operations == 0 { + self.idle.notify_all(); + } + } + + pub(crate) fn finish_attach_success(&self) -> PortLifecycleState { + let mut inner = self.lock_inner(); + + if inner.state == PortLifecycleState::Attaching { + inner.state = PortLifecycleState::Attached; + } + + if inner.active_operations != 0 { + inner.active_operations -= 1; + } + if inner.active_operations == 0 { + self.idle.notify_all(); + } + + inner.state + } + + pub(crate) fn finish_attach_failure(&self) { + let mut inner = self.lock_inner(); + inner.state = PortLifecycleState::Detaching; + + if inner.active_operations != 0 { + inner.active_operations -= 1; + } + if inner.active_operations == 0 { + self.idle.notify_all(); + } + } + + pub(crate) fn begin_detaching(&self) { + let mut inner = self.lock_inner(); + inner.state = PortLifecycleState::Detaching; // new begin_operation() calls now fail with EBUSY; then wait for in-flight ones to drain + + while
inner.active_operations != 0 { + inner = self.idle.wait(inner).unwrap_or_else(|err| err.into_inner()); + } + } +} + +pub(crate) struct PortOperationGuard { // dropping it releases one active operation on the port lifecycle + lifecycle: Arc, +} + +impl PortOperationGuard { + pub(crate) fn new(lifecycle: Arc) -> Self { + Self { lifecycle } + } +} + +impl Drop for PortOperationGuard { + fn drop(&mut self) { + self.lifecycle.finish_operation(); + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum PortPmState { + Active, + Suspended, +} +impl PortPmState { + pub fn as_str(&self) -> &'static str { + match self { + Self::Active => "active", + Self::Suspended => "suspended", + } + } } impl PortState { @@ -463,6 +662,7 @@ impl Xhci { handles: CHashMap::new(), next_handle: AtomicUsize::new(0), port_states: CHashMap::new(), + staged_port_states: CHashMap::new(), drivers: CHashMap::new(), scheme_name, @@ -615,29 +815,24 @@ impl Xhci { route_string: 0, }; - //Get the CCS and CSC flags - let (ccs, csc, flags) = { + // Only queue ports that are actually connected at startup. A stale CSC bit on an + // otherwise disconnected port should not trigger a full attach attempt. + let (ccs, flags) = { let mut ports = self.ports.lock().unwrap(); let port = &mut ports[port_id.root_hub_port_index()]; let flags = port.flags(); let ccs = flags.contains(PortFlags::CCS); - let csc = flags.contains(PortFlags::CSC); - (ccs, csc, flags) + (ccs, flags) }; debug!("Port {} has flags {:?}", port_id, flags); - match (ccs, csc) { - (false, false) => { // Nothing is connected, and there was no port status change - //Do nothing - } - _ => { - //Either something is connected, or nothing is connected and a port status change was asserted.
- self.device_enumerator_sender - .send(DeviceEnumerationRequest { port_id }) - .expect("Failed to generate the port enumeration request!"); - } + if ccs { + info!("xhcid: queueing initial enumeration for port {} with flags {:?}", port_id, flags); + self.device_enumerator_sender + .send(DeviceEnumerationRequest { port_id }) + .expect("Failed to generate the port enumeration request!"); } } } @@ -757,7 +952,7 @@ impl Xhci { trace!("Slot is enabled!"); self::scheme::handle_event_trb("ENABLE_SLOT", &event_trb, &command_trb)?; - //self.event_handler_finished(); + self.event_handler_finished(); Ok(event_trb.event_slot()) } @@ -768,7 +963,7 @@ impl Xhci { .await; self::scheme::handle_event_trb("DISABLE_SLOT", &event_trb, &command_trb)?; - //self.event_handler_finished(); + self.event_handler_finished(); Ok(()) } @@ -793,11 +988,13 @@ impl Xhci { } pub async fn attach_device(&self, port_id: PortId) -> syscall::Result<()> { - if self.port_states.contains_key(&port_id) { + if self.port_states.contains_key(&port_id) || self.staged_port_states.contains_key(&port_id) { debug!("Already contains port {}", port_id); return Err(syscall::Error::new(EAGAIN)); } + info!("xhcid: begin attach for port {}", port_id); + let (data, state, speed, flags) = { let port = &self.ports.lock().unwrap()[port_id.root_hub_port_index()]; (port.read(), port.state(), port.speed(), port.flags()) @@ -808,74 +1005,114 @@ impl Xhci { port_id, data, state, speed, flags ); - if flags.contains(port::PortFlags::CCS) { - let slot_ty = match self.supported_protocol(port_id) { - Some(protocol) => protocol.proto_slot_ty(), - None => { - warn!("Failed to find supported protocol information for port"); - 0 - } - }; - - debug!("Slot type: {}", slot_ty); - debug!("Enabling slot."); - let slot = match self.enable_port_slot(slot_ty).await { - Ok(ok) => ok, - Err(err) => { - error!("Failed to enable slot for port {}: {}", port_id, err); - return Err(err); - } - }; + if !flags.contains(port::PortFlags::CCS) { +
warn!("Attempted to attach a device that didnt have CCS=1"); + return Ok(()); + } - debug!("Enabled port {}, which the xHC mapped to {}", port_id, slot); + let early_quirks = crate::usb_quirks::lookup_usb_quirks_early(port_id); // NOTE(review): currently always empty, so quirk-dependent address_device behaviour never triggers here + let slot_ty = match self.supported_protocol(port_id) { + Some(protocol) => protocol.proto_slot_ty(), + None => { + warn!("Failed to find supported protocol information for port {}", port_id); + 0 + } + }; - //TODO: get correct speed for child devices - let protocol_speed = self - .lookup_psiv(port_id, speed) - .expect("Failed to retrieve speed ID"); + debug!("Slot type: {}", slot_ty); + debug!("Enabling slot."); + let slot = match self.enable_port_slot(slot_ty).await { + Ok(ok) => ok, + Err(err) => { + error!("Failed to enable slot for port {}: {}", port_id, err); + return Err(err); + } + }; - let mut input = unsafe { self.alloc_dma_zeroed::>()? }; + debug!("Enabled port {}, which the xHC mapped to {}", port_id, slot); + info!("xhcid: enabled slot {} for port {}", slot, port_id); - debug!("Attempting to address the device"); - let mut ring = match self - .address_device(&mut input, port_id, slot_ty, slot, protocol_speed, speed) - .await - { - Ok(device_ring) => device_ring, - Err(err) => { - error!("Failed to address device for port {}: `{}`", port_id, err); - return Err(err); + let protocol_speed = match self.lookup_psiv(port_id, speed) { + Some(protocol_speed) => protocol_speed, + None => { + let err = Error::new(EIO); + error!("Failed to retrieve speed ID for port {}", port_id); + if let Err(disable_err) = self.disable_port_slot(slot).await { + warn!( + "Failed to disable slot {} after speed lookup failure on port {}: {}", + slot, port_id, disable_err + ); } - }; - - debug!("Addressed device"); + return Err(err); + } + }; - // TODO: Should the descriptors be cached in PortState, or refetched? + let mut input = unsafe { self.alloc_dma_zeroed::>()?
}; - let mut port_state = PortState { + debug!("Attempting to address the device"); + let ring = match self + .address_device( + &mut input, + port_id, + slot_ty, slot, protocol_speed, - input_context: Mutex::new(input), - dev_desc: None, - cfg_idx: None, - endpoint_states: std::iter::once(( - 0, - EndpointState { - transfer: RingOrStreams::Ring(ring), - driver_if_state: EndpIfState::Init, - }, - )) - .collect::>(), - }; - self.port_states.insert(port_id, port_state); - debug!("Got port states!"); + speed, + early_quirks, + ) + .await + { + Ok(device_ring) => device_ring, + Err(err) => { + error!("Failed to address device for port {}: `{}`", port_id, err); + if let Err(disable_err) = self.disable_port_slot(slot).await { + warn!( + "Failed to disable slot {} after address failure on port {}: {}", + slot, port_id, disable_err + ); + } + return Err(err); + } + }; + + debug!("Addressed device"); + info!("xhcid: addressed device on port {} slot {}", port_id, slot); + + let lifecycle = Arc::new(PortLifecycle::new_attaching()); // kept in staged_port_states while descriptors are read; moved to port_states only after finish_attach_success() + let port_state = PortState { + slot, + protocol_speed, + input_context: Mutex::new(input), + dev_desc: None, + cfg_idx: None, + endpoint_states: std::iter::once(( + 0, + EndpointState { + transfer: RingOrStreams::Ring(ring), + driver_if_state: EndpIfState::Init, + }, + )) + .collect::>(), + quirks: early_quirks, + pm_state: PortPmState::Active, + lifecycle: Arc::clone(&lifecycle), + }; + self.staged_port_states.insert(port_id, port_state); + debug!("Got staged port state!"); - // Ensure correct packet size is used + let attach_result = async { let dev_desc_8_byte = self.fetch_dev_desc_8_byte(port_id, slot).await?; + info!("xhcid: fetched 8-byte device descriptor for port {}", port_id); { - let mut port_state = self.port_states.get_mut(&port_id).unwrap(); + let mut port_state = self + .staged_port_states + .get_mut(&port_id) + .ok_or(Error::new(ENOENT))?; - let mut input = port_state.input_context.lock().unwrap(); + let mut input = port_state +
.input_context + .lock() + .unwrap_or_else(|err| err.into_inner()); self.update_max_packet_size(&mut *input, slot, dev_desc_8_byte) .await?; @@ -884,38 +1121,131 @@ impl Xhci { debug!("Got the 8 byte dev descriptor: {:X?}", dev_desc_8_byte); let dev_desc = self.get_desc(port_id, slot).await?; + info!( + "xhcid: got descriptors for port {} vendor {:04x} product {:04x}", + port_id, + dev_desc.vendor, + dev_desc.product + ); + let quirks = early_quirks + | crate::usb_quirks::lookup_usb_quirks(dev_desc.vendor, dev_desc.product); debug!("Got the full device descriptor!"); - self.port_states.get_mut(&port_id).unwrap().dev_desc = Some(dev_desc); + { + let mut port_state = self + .staged_port_states + .get_mut(&port_id) + .ok_or(Error::new(ENOENT))?; + port_state.quirks = quirks; + port_state.dev_desc = Some(dev_desc); + } debug!("Got the port states again!"); { - let mut port_state = self.port_states.get_mut(&port_id).unwrap(); - - let mut input = port_state.input_context.lock().unwrap(); + let mut port_state = self + .staged_port_states + .get_mut(&port_id) + .ok_or(Error::new(ENOENT))?; + + let mut input = port_state + .input_context + .lock() + .unwrap_or_else(|err| err.into_inner()); debug!("Got the input context!"); - let dev_desc = port_state.dev_desc.as_ref().unwrap(); + let dev_desc = port_state.dev_desc.as_ref().ok_or(Error::new(EIO))?; self.update_default_control_pipe(&mut *input, slot, dev_desc) .await?; } debug!("Updated the default control pipe"); + Ok(()) + } + .await; + + match attach_result { + Ok(()) => { + if let Some(delay_ms) = + self.consume_test_hook_delay_ms("delay_before_attach_commit_ms=") + { + info!( + "xhcid: test hook delaying attach commit for port {} by {} ms", + port_id, delay_ms + ); + thread::sleep(Duration::from_millis(delay_ms)); + } - match self.spawn_drivers(port_id) { - Ok(()) => (), - Err(err) => { - error!("Failed to spawn driver for port {}: `{}`", port_id, err) + if lifecycle.finish_attach_success() !=
PortLifecycleState::Attached { + warn!( + "attach for port {} completed after detach already started; skipping publication", + port_id + ); + return Err(Error::new(EBUSY)); } + + let staged_port_state = self + .staged_port_states + .remove(&port_id) + .ok_or(Error::new(ENOENT))?; + self.port_states.insert(port_id, staged_port_state); + + match self.spawn_drivers(port_id) { + Ok(()) => (), + Err(err) => { + error!("Failed to spawn driver for port {}: `{}`", port_id, err) + } + } + info!("xhcid: finished attach for port {}", port_id); + Ok(()) + } + Err(err) => { + lifecycle.finish_attach_failure(); + if let Err(detach_err) = self.detach_device(port_id).await { + warn!( + "failed to clean up attach failure on port {}: {}", + port_id, detach_err + ); + } + Err(err) } - } else { - warn!("Attempted to attach a device that didnt have CCS=1"); } - - Ok(()) } pub async fn detach_device(&self, port_id: PortId) -> Result { - if let Some(children) = self.drivers.remove(&port_id) { + // Copy the slot and lifecycle handle out and release the CHashMap guards here. + // Holding a guard across begin_detaching() or the awaits below deadlocks: in-flight + // operations (port_state_mut()) and a concurrent attach (get_mut()) need this entry, + // and the final remove() must lock it too. + // Both guards are dropped by the end of the `let lookup` statement. + let lookup = match self.port_states.get(&port_id) { + Some(state) => Some((state.slot, Arc::clone(&state.lifecycle), true)), + None => self + .staged_port_states + .get(&port_id) + .map(|state| (state.slot, Arc::clone(&state.lifecycle), false)), + }; + + let Some((slot, lifecycle, was_published)) = lookup else { + debug!( + "Attempted to detach from port {}, which wasn't previously attached.", + port_id + ); + return Ok(false); + }; + + info!("xhcid: begin detach quiesce for port {}", port_id); + lifecycle.begin_detaching(); + info!("xhcid: detach quiesce complete for port {}", port_id); + + if let Some(delay_ms) = self.consume_test_hook_delay_ms("delay_before_detach_disable_ms=") { + info!( + "xhcid: test hook delaying detach disable for port {} by {} ms", + port_id, delay_ms + ); + thread::sleep(Duration::from_millis(delay_ms)); + } + + if was_published { + if let Some(children) = self.drivers.remove(&port_id) { for
mut child in children { info!("killing driver process {} for port {}", child.id(), port_id); match child.kill() { @@ -961,21 +1291,26 @@ impl Xhci { } } } + } - if let Some(state) = self.port_states.remove(&port_id) { - debug!("disabling port slot {} for port {}", state.slot, port_id); - let result = self.disable_port_slot(state.slot).await.and(Ok(true)); - debug!( - "disabled port slot {} for port {} with result: {:?}", - state.slot, port_id, result - ); - result - } else { - debug!( - "Attempted to detach from port {}, which wasn't previously attached.", - port_id - ); - Ok(false) + debug!("disabling port slot {} for port {}", slot, port_id); + let result = self.disable_port_slot(slot).await; + // The lifecycle is already Detaching, so this entry can never serve requests again; + // drop it even if disabling fails so a later attach on this port is not refused. + if was_published { + let _ = self.port_states.remove(&port_id); + } else { + let _ = self.staged_port_states.remove(&port_id); + } + match result { + Ok(()) => { + debug!("disabled port slot {} for port {}", slot, port_id); + Ok(true) + } + Err(err) => { + warn!("failed to disable port slot {} for port {}: {}", slot, port_id, err); + Err(err) + } } } @@ -1004,7 +1339,7 @@ impl Xhci { .await; self::scheme::handle_event_trb("EVALUATE_CONTEXT", &event_trb, &command_trb)?; - //self.event_handler_finished(); + self.event_handler_finished(); Ok(()) } @@ -1039,7 +1374,7 @@ impl Xhci { debug!("Completed the command to update the default control pipe"); self::scheme::handle_event_trb("EVALUATE_CONTEXT", &event_trb, &command_trb)?; - //self.event_handler_finished(); + self.event_handler_finished(); Ok(()) } @@ -1052,6 +1387,7 @@ impl Xhci { slot: u8, protocol_speed: &ProtocolSpeed, speed: u8, + quirks: crate::usb_quirks::UsbQuirkFlags, ) -> Result { // Collect MTT, parent port number, parent slot ID let mut mtt = false; @@ -1162,11 +1498,16 @@ impl Xhci { let input_context_physical = input_context.physical(); - let (event_trb, _) = self - .execute_command(|trb, cycle| { - trb.address_device(slot, input_context_physical, false, cycle) - }) - .await; + let address_timeout = if
quirks.contains(crate::usb_quirks::UsbQuirkFlags::SHORT_SET_ADDR_TIMEOUT) + { + Timeout::from_millis(100) // NOTE(review): confirm `Timeout` is in scope in this module; the import changes in this diff do not add it + } else { + Timeout::from_secs(1) + }; + + let (event_trb, _) = self.execute_command_with_timeout(address_timeout, |trb, cycle| { // NOTE(review): blocks the thread inside this async fn, and on timeout the Address Device command is left queued + trb.address_device(slot, input_context_physical, false, cycle) + })?; if event_trb.completion_code() != TrbCompletionCode::Success as u8 { error!( @@ -1175,10 +1516,10 @@ impl Xhci { port, event_trb.completion_code() ); - //self.event_handler_finished(); + self.event_handler_finished(); return Err(Error::new(EIO)); } - //self.event_handler_finished(); + self.event_handler_finished(); Ok(ring) } @@ -1281,6 +1622,12 @@ impl Xhci { ifdesc.sub_class, ifdesc.protocol, ); + match driver.name.as_str() { + "USB HID" => info!("USB HID driver spawned"), + "SCSI over USB" => info!("USB SCSI driver spawned"), + "USB HUB" => info!("USB HUB driver spawned"), + _ => {} + } let (command, args) = driver.command.split_first().ok_or(Error::new(EBADMSG))?; let command = if command.starts_with('/') { @@ -1487,3 +1834,52 @@ lazy_static!
{ toml::from_slice::(TOML).expect("Failed to parse internally embedded config file") }; } + +#[cfg(test)] +mod tests { + use super::{Xhci, XHCID_TEST_HOOK_MAX_DELAY_MS}; + use std::fs; + use std::path::Path; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn unique_test_hook_path() -> String { // nanosecond suffix keeps each test on its own hook file + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + format!("/tmp/xhcid-test-hook-{}", unique) + } + + #[test] + fn consume_test_hook_only_clears_matching_command() { + let path = unique_test_hook_path(); + fs::write(&path, "fail_after_set_configuration\n").unwrap(); + + assert!(!Xhci::<16>::consume_test_hook_from_path( + &path, + "fail_after_configure_endpoint" + )); + assert!(Path::new(&path).exists()); + + assert!(Xhci::<16>::consume_test_hook_from_path( + &path, + "fail_after_set_configuration" + )); + assert!(!Path::new(&path).exists()); + } + + #[test] + fn consume_test_hook_delay_clamps_and_clears() { + let path = unique_test_hook_path(); + fs::write(&path, "delay_before_attach_commit_ms=999999\n").unwrap(); + + assert_eq!( + Xhci::<16>::consume_test_hook_delay_ms_from_path( + &path, + "delay_before_attach_commit_ms=" + ), + Some(XHCID_TEST_HOOK_MAX_DELAY_MS) + ); + assert!(!Path::new(&path).exists()); + } +} diff --git a/drivers/usb/xhcid/src/xhci/scheme.rs b/drivers/usb/xhcid/src/xhci/scheme.rs index f2d439a4..bc6d7fca 100644 --- a/drivers/usb/xhcid/src/xhci/scheme.rs +++ b/drivers/usb/xhcid/src/xhci/scheme.rs @@ -18,12 +18,15 @@ //!
port/endpoints//data use std::convert::TryFrom; use std::io::prelude::*; +use std::io::Write; use std::ops::Deref; +use std::sync::Arc; use std::sync::atomic; use std::{cmp, fmt, io, mem, str}; use common::dma::Dma; use futures::executor::block_on; +use futures::FutureExt; use log::{debug, error, info, trace, warn}; use redox_scheme::scheme::SchemeSync; use smallvec::SmallVec; @@ -32,16 +35,16 @@ use common::io::Io; use redox_scheme::{CallerCtx, OpenResult}; use syscall::schemev2::NewFdFlags; use syscall::{ - Error, Result, Stat, EACCES, EBADF, EBADFD, EBADMSG, EINVAL, EIO, EISDIR, ENOENT, ENOSYS, - ENOTDIR, EOPNOTSUPP, EPROTO, ESPIPE, MODE_CHR, MODE_DIR, MODE_FILE, O_DIRECTORY, O_RDWR, - O_STAT, O_WRONLY, SEEK_CUR, SEEK_END, SEEK_SET, + Error, Result, Stat, EACCES, EBADF, EBADFD, EBADMSG, EBUSY, EINVAL, EIO, EISDIR, ENOENT, + ENOSYS, ENOTDIR, EOPNOTSUPP, EPROTO, ESPIPE, MODE_CHR, MODE_DIR, MODE_FILE, O_DIRECTORY, + O_RDWR, O_STAT, O_WRONLY, SEEK_CUR, SEEK_END, SEEK_SET, }; use super::{port, usb}; use super::{EndpointState, PortId, Xhci}; use super::context::{ - SlotState, StreamContextArray, StreamContextType, CONTEXT_32, CONTEXT_64, + EndpointContext, SlotState, StreamContextArray, StreamContextType, CONTEXT_32, CONTEXT_64, SLOT_CONTEXT_STATE_MASK, SLOT_CONTEXT_STATE_SHIFT, }; use super::extended::ProtocolSpeed; @@ -60,10 +63,16 @@ lazy_static!
{ .expect("Failed to create the regex for the port/attach scheme."); static ref REGEX_PORT_DETACH: Regex = Regex::new(r"^port([\d\.]+)/detach$") .expect("Failed to create the regex for the port/detach scheme."); + static ref REGEX_PORT_SUSPEND: Regex = Regex::new(r"^port([\d\.]+)/suspend$") + .expect("Failed to create the regex for the port/suspend scheme."); + static ref REGEX_PORT_RESUME: Regex = Regex::new(r"^port([\d\.]+)/resume$") + .expect("Failed to create the regex for the port/resume scheme."); static ref REGEX_PORT_DESCRIPTORS: Regex = Regex::new(r"^port([\d\.]+)/descriptors$") .expect("Failed to create the regex for the port/descriptors"); static ref REGEX_PORT_STATE: Regex = Regex::new(r"^port([\d\.]+)/state$") .expect("Failed to create the regex for the port/state scheme"); + static ref REGEX_PORT_PM_STATE: Regex = Regex::new(r"^port([\d\.]+)/pm_state$") + .expect("Failed to create the regex for the port/pm_state scheme"); static ref REGEX_PORT_REQUEST: Regex = Regex::new(r"^port([\d\.]+)/request$") .expect("Failed to create the regex for the port/request scheme"); static ref REGEX_PORT_ENDPOINTS: Regex = Regex::new(r"^port([\d\.]+)/endpoints$") @@ -137,12 +146,15 @@ pub enum Handle { Port(PortId, Vec), // port, contents PortDesc(PortId, Vec), // port, contents PortState(PortId), // port + PortPmState(PortId), // port PortReq(PortId, PortReqState), // port, state Endpoints(PortId, Vec), // port, contents Endpoint(PortId, u8, EndpointHandleTy), // port, endpoint, state ConfigureEndpoints(PortId), // port AttachDevice(PortId), // port DetachDevice(PortId), // port + SuspendDevice(PortId), // port + ResumeDevice(PortId), // port SchemeRoot, } @@ -172,6 +184,8 @@ enum SchemeParameters { PortDesc(PortId), // port number /// /port/state PortState(PortId), // port number + /// /port/pm_state + PortPmState(PortId), // port number /// /port/request PortReq(PortId), // port number /// /port/endpoints @@ -187,6 +201,10 @@ enum SchemeParameters {
AttachDevice(PortId), // port number /// /port/detach DetachDevice(PortId), // port number + /// /port/suspend + SuspendDevice(PortId), // port number + /// /port/resume + ResumeDevice(PortId), // port number } impl Handle { @@ -209,6 +227,9 @@ impl Handle { Handle::PortState(port_num) => { format!("port{}/state", port_num) } + Handle::PortPmState(port_num) => { + format!("port{}/pm_state", port_num) + } Handle::PortReq(port_num, _) => { format!("port{}/request", port_num) } @@ -235,6 +256,12 @@ impl Handle { Handle::DetachDevice(port_num) => { format!("port{}/detach", port_num) } + Handle::SuspendDevice(port_num) => { + format!("port{}/suspend", port_num) + } + Handle::ResumeDevice(port_num) => { + format!("port{}/resume", port_num) + } Handle::SchemeRoot => String::from(""), } } @@ -258,10 +285,13 @@ impl Handle { &Handle::PortReq(_, PortReqState::Tmp) => unreachable!(), &Handle::PortReq(_, PortReqState::TmpSetup(_)) => unreachable!(), &Handle::PortState(_) => HandleType::Character, + &Handle::PortPmState(_) => HandleType::Character, &Handle::PortReq(_, _) => HandleType::Character, &Handle::ConfigureEndpoints(_) => HandleType::Character, &Handle::AttachDevice(_) => HandleType::Character, &Handle::DetachDevice(_) => HandleType::Character, + &Handle::SuspendDevice(_) => HandleType::Character, + &Handle::ResumeDevice(_) => HandleType::Character, &Handle::Endpoint(_, _, ref st) => match st { EndpointHandleTy::Data => HandleType::Character, EndpointHandleTy::Ctl => HandleType::Character, @@ -289,10 +319,13 @@ impl Handle { &Handle::PortReq(_, PortReqState::Tmp) => None, &Handle::PortReq(_, PortReqState::TmpSetup(_)) => None, &Handle::PortState(_) => None, + &Handle::PortPmState(_) => None, &Handle::PortReq(_, _) => None, &Handle::ConfigureEndpoints(_) => None, &Handle::AttachDevice(_) => None, &Handle::DetachDevice(_) => None, + &Handle::SuspendDevice(_) => None, + &Handle::ResumeDevice(_) => None, &Handle::Endpoint(_, _, ref st) => match st { EndpointHandleTy::Data
=> None, EndpointHandleTy::Ctl => None, @@ -383,6 +416,14 @@ impl SchemeParameters { let port_num = get_port_id_from_regex(®EX_PORT_DETACH, scheme, 0)?; Ok(Self::DetachDevice(port_num)) + } else if REGEX_PORT_SUSPEND.is_match(scheme) { + let port_num = get_port_id_from_regex(®EX_PORT_SUSPEND, scheme, 0)?; + + Ok(Self::SuspendDevice(port_num)) + } else if REGEX_PORT_RESUME.is_match(scheme) { + let port_num = get_port_id_from_regex(®EX_PORT_RESUME, scheme, 0)?; + + Ok(Self::ResumeDevice(port_num)) } else if REGEX_PORT_DESCRIPTORS.is_match(scheme) { let port_num = get_port_id_from_regex(®EX_PORT_DESCRIPTORS, scheme, 0)?; @@ -391,6 +432,10 @@ impl SchemeParameters { let port_num = get_port_id_from_regex(®EX_PORT_STATE, scheme, 0)?; Ok(Self::PortState(port_num)) + } else if REGEX_PORT_PM_STATE.is_match(scheme) { + let port_num = get_port_id_from_regex(®EX_PORT_PM_STATE, scheme, 0)?; + + Ok(Self::PortPmState(port_num)) } else if REGEX_PORT_REQUEST.is_match(scheme) { let port_num = get_port_id_from_regex(®EX_PORT_REQUEST, scheme, 0)?; @@ -556,6 +601,47 @@ impl AnyDescriptor { } impl Xhci { + fn begin_port_operation( + &self, + port: PortId, + allow_attaching: bool, + require_active_pm: bool, + ) -> Result { + let lifecycle = { // scoped so the port_states read guard is released before begin_operation() + let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; + Arc::clone(&port_state.lifecycle) + }; + + lifecycle.begin_operation(allow_attaching)?; // EBUSY once detach has started + let guard = super::PortOperationGuard::new(lifecycle); + + if require_active_pm { + let pm_state = self + .port_states + .get(&port) + .ok_or(Error::new(EBADFD))?
+ .pm_state; + if pm_state != super::PortPmState::Active { + drop(guard); + return Err(Error::new(EBUSY)); + } + } + + Ok(guard) + } + + fn begin_transfer_operation(&self, port: PortId) -> Result { // NOTE(review): begin_port_operation only consults port_states, so staged (Attaching) ports still get EBADFD + self.begin_port_operation(port, true, true) + } + + fn begin_routable_operation(&self, port: PortId) -> Result { + self.begin_port_operation(port, false, true) + } + + fn begin_attached_operation(&self, port: PortId) -> Result { + self.begin_port_operation(port, false, false) + } + async fn new_if_desc( &self, port_id: PortId, @@ -564,15 +650,22 @@ impl Xhci { endps: impl IntoIterator, hid_descs: impl IntoIterator, lang_id: u16, + quirks: crate::usb_quirks::UsbQuirkFlags, ) -> Result { Ok(IfDesc { alternate_setting: desc.alternate_setting, class: desc.class, interface_str: if desc.interface_str > 0 { - Some( + if quirks.contains(crate::usb_quirks::UsbQuirkFlags::BAD_DESCRIPTOR) { self.fetch_string_desc(port_id, slot, desc.interface_str, lang_id) - .await?, - ) + .await + .ok() + } else { + Some( + self.fetch_string_desc(port_id, slot, desc.interface_str, lang_id) + .await?, + ) + } } else { None }, @@ -590,10 +683,9 @@ impl Xhci { /// # Locking /// This function will lock `Xhci::cmd` and `Xhci::dbs`. pub async fn execute_command(&self, f: F) -> (Trb, Trb) { - //TODO: find out why this bit is set earlier!
if self.interrupt_is_pending(0) { debug!("The EHB bit is already set!"); - //self.force_clear_interrupt(0); + self.force_clear_interrupt(0); } let next_event = { @@ -628,6 +720,54 @@ impl Xhci { (event_trb, command_trb) } + pub fn execute_command_with_timeout( + &self, + timeout: common::timeout::Timeout, + f: F, + ) -> Result<(Trb, Trb)> { + if self.interrupt_is_pending(0) { + debug!("The EHB bit is already set!"); + self.force_clear_interrupt(0); + } + + let next_event = { + let mut command_ring = self.cmd.lock().unwrap(); + let (cmd_index, cycle) = (command_ring.next_index(), command_ring.cycle); + + debug!("Sending command with cycle bit {}", cycle as u8); + + { + let command_trb = &mut command_ring.trbs[cmd_index]; + f(command_trb, cycle); + } + + let command_trb = &command_ring.trbs[cmd_index]; + self.next_command_completion_event_trb( + &*command_ring, + command_trb, + EventDoorbell::new(self, 0, 0), + ) + }; + + let mut next_event = Box::pin(next_event); + + loop { + if let Some(trbs) = next_event.as_mut().now_or_never() { + let event_trb = trbs.event_trb; + let command_trb = trbs.src_trb.ok_or(Error::new(EIO))?; + + assert_eq!( + event_trb.trb_type(), + TrbType::CommandCompletion as u8, + "The IRQ reactor (or the xHC) gave an invalid event TRB" + ); + + return Ok((event_trb, command_trb)); + } + + timeout.run().map_err(|()| Error::new(EIO))?; // NOTE(review): busy-polls with a no-op waker; on expiry the future is dropped while the command is still queued on the ring + } + } pub async fn execute_control_transfer( &self, port_num: PortId, @@ -639,6 +779,9 @@ impl Xhci { where D: FnMut(&mut Trb, bool) -> ControlFlow, { + let _op = self.begin_transfer_operation(port_num)?; + self.ensure_port_active(port_num)?; + let future = { let mut port_state = self.port_state_mut(port_num)?; let slot = port_state.slot; @@ -690,7 +833,21 @@ impl Xhci { handle_transfer_event_trb("CONTROL_TRANSFER", &event_trb, &status_trb)?; - //self.event_handler_finished(); + let delay_ctrl_msg = self + .port_states + .get(&port_num) + .map(|port_state| { + port_state + .quirks +
.contains(crate::usb_quirks::UsbQuirkFlags::DELAY_CTRL_MSG) + }) + .unwrap_or(false); + + if delay_ctrl_msg { + std::thread::sleep(std::time::Duration::from_millis(20)); // DELAY_CTRL_MSG quirk; blocks the calling thread, not just this future + } + + self.event_handler_finished(); Ok(event_trb) } @@ -709,6 +866,9 @@ impl Xhci { where D: FnMut(&mut Trb, bool) -> ControlFlow, { + let _op = self.begin_transfer_operation(port_num)?; + self.ensure_port_active(port_num)?; + let endp_idx = endp_num.checked_sub(1).ok_or(Error::new(EIO))?; let mut port_state = self.port_state_mut(port_num)?; @@ -785,7 +945,31 @@ impl Xhci { let event_trb = trbs.event_trb; let transfer_trb = trbs.src_trb.ok_or(Error::new(EIO))?; - handle_transfer_event_trb("EXECUTE_TRANSFER", &event_trb, &transfer_trb)?; + if let Err(err) = handle_transfer_event_trb("EXECUTE_TRANSFER", &event_trb, &transfer_trb) + { + let need_reset = self + .port_states + .get(&port_num) + .map(|port_state| { + port_state + .quirks + .contains(crate::usb_quirks::UsbQuirkFlags::NEED_RESET) + }) + .unwrap_or(false); + + if need_reset { + if let Err(reset_err) = self.reset_device_slot(port_num).await { + error!( + "EXECUTE_TRANSFER reset recovery failed for port {}: {}", + port_num, reset_err + ); + } + } + + self.event_handler_finished(); + + return Err(err); + } // FIXME: EDTLA if event data was set if event_trb.completion_code() != TrbCompletionCode::ShortPacket as u8 @@ -798,6 +982,8 @@ impl Xhci { // TODO: Handle event data trace!("EVENT DATA: {:?}", event_trb.event_data()); + self.event_handler_finished(); + Ok(event_trb) } async fn device_req_no_data(&self, port: PortId, req: usb::Setup) -> Result<()> { @@ -857,10 +1043,27 @@ impl Xhci { trb.reset_endpoint(slot, endp_num_xhc, tsp, cycle); }) .await; - //self.event_handler_finished(); + self.event_handler_finished(); handle_event_trb("RESET_ENDPOINT", &event_trb, &command_trb) } + async fn reset_device_slot(&self, port_num: PortId) -> Result<()> { + let slot = self + .port_states + .get(&port_num) + .ok_or(Error::new(EBADF))?
+ .slot; + + let (event_trb, command_trb) = self + .execute_command(|trb, cycle| { + trb.reset_device(slot, cycle); + }) + .await; + + self.event_handler_finished(); + + handle_event_trb("RESET_DEVICE", &event_trb, &command_trb) + } fn endp_ctx_interval(speed_id: &ProtocolSpeed, endp_desc: &EndpDesc) -> u8 { /// Logarithmic (base 2) 125 µs periods per millisecond. @@ -949,35 +1152,106 @@ impl Xhci { self.port_states.get_mut(&port).ok_or(Error::new(EBADF)) } + fn restore_configure_input_context( + &self, + port: PortId, + snapshot: ConfigureContextSnapshot, + endpoint_snapshots: &[(usize, EndpointContextSnapshot)], + ) -> Result { + let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; + let mut input_context = port_state + .input_context + .lock() + .unwrap_or_else(|err| err.into_inner()); + + input_context.add_context.write(snapshot.add_context); + input_context.drop_context.write(snapshot.drop_context); + input_context.control.write(snapshot.control); + input_context.device.slot.a.write(snapshot.slot_a); + input_context.device.slot.b.write(snapshot.slot_b); + input_context.device.slot.c.write(snapshot.slot_c); + + for (endp_i, endp_snapshot) in endpoint_snapshots { + input_context.device.endpoints[*endp_i].a.write(endp_snapshot.a); + input_context.device.endpoints[*endp_i].b.write(endp_snapshot.b); + input_context.device.endpoints[*endp_i].trl.write(endp_snapshot.trl); + input_context.device.endpoints[*endp_i].trh.write(endp_snapshot.trh); + input_context.device.endpoints[*endp_i].c.write(endp_snapshot.c); + } + + Ok(input_context.physical()) + } + + async fn rollback_configure_attempt( + &self, + port: PortId, + slot: u8, + configure_snapshot: ConfigureContextSnapshot, + endpoint_snapshots: &[(usize, EndpointContextSnapshot)], + stage: &str, + ) { + let rollback_input_context_physical = match self.restore_configure_input_context( + port, + configure_snapshot, + endpoint_snapshots, + ) { + Ok(physical) => physical, + Err(restore_err) => { + 
warn!( + "failed to restore configure input context after {}: {:?}", + stage, restore_err + ); + return; + } + }; + + let (rollback_event_trb, rollback_command_trb) = self + .execute_command(|trb, cycle| { + trb.configure_endpoint(slot, rollback_input_context_physical, cycle) + }) + .await; + + self.event_handler_finished(); + + if let Err(rollback_err) = + handle_event_trb("CONFIGURE_ENDPOINT_ROLLBACK", &rollback_event_trb, &rollback_command_trb) + { + warn!( + "failed to roll back CONFIGURE_ENDPOINT after {}: {:?}", + stage, rollback_err + ); + } + } + async fn configure_endpoints_once( &self, port: PortId, req: &ConfigureEndpointsReq, ) -> Result<()> { - let (endp_desc_count, new_context_entries, configuration_value) = { - let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; - - port_state.cfg_idx = Some(req.config_desc); + let (dev_desc, endpoint_descs, new_context_entries, configuration_value) = { + let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; + let dev_desc = port_state.dev_desc.as_ref().ok_or(Error::new(EBADFD))?.clone(); - let config_desc = port_state - .dev_desc - .as_ref() - .unwrap() + let config_desc = dev_desc .config_descs .iter() .find(|desc| desc.configuration_value == req.config_desc) .ok_or(Error::new(EBADFD))?; + let configuration_value = config_desc.configuration_value; - //TODO: USE ENDPOINTS FROM ALL INTERFACES - let mut endp_desc_count = 0; - let mut new_context_entries = 1; - for if_desc in config_desc.interface_descs.iter() { - for endpoint in if_desc.endpoints.iter() { - endp_desc_count += 1; - let entry = Self::endp_num_to_dci(endp_desc_count, endpoint); - if entry > new_context_entries { - new_context_entries = entry; - } + let endpoint_descs = config_desc + .interface_descs + .iter() + .flat_map(|if_desc| if_desc.endpoints.iter().copied()) + .collect::>(); + + let endp_desc_count = endpoint_descs.len(); + let mut new_context_entries = 1u8; + for (endp_idx, endpoint) in 
endpoint_descs.iter().enumerate() { + let endp_num = endp_idx as u8 + 1; + let entry = Self::endp_num_to_dci(endp_num, endpoint); + if entry > new_context_entries { + new_context_entries = entry; } } new_context_entries += 1; @@ -988,11 +1262,13 @@ impl Xhci { } ( - endp_desc_count, + dev_desc, + endpoint_descs, new_context_entries, - config_desc.configuration_value, + configuration_value, ) }; + let endp_desc_count = endpoint_descs.len(); let lec = self.cap.lec(); let log_max_psa_size = self.cap.max_psa_size(); @@ -1002,9 +1278,160 @@ impl Xhci { Error::new(EIO) })?; + let mut endpoint_programs = Vec::with_capacity(endp_desc_count as usize); + let mut staged_endpoint_states = Vec::with_capacity(endp_desc_count as usize); + { + for (endp_idx, endp_desc) in endpoint_descs.iter().enumerate() { + let endp_num = endp_idx as u8 + 1; + + let endp_num_xhc = Self::endp_num_to_dci(endp_num, endp_desc); + let usb_log_max_streams = endp_desc.log_max_streams(); + + let primary_streams = if let Some(log_max_streams) = usb_log_max_streams { + if log_max_psa_size != 0 { + cmp::min(u8::from(log_max_streams), log_max_psa_size + 1) - 1 + } else { + 0 + } + } else { + 0 + }; + let linear_stream_array = primary_streams != 0; + + let mult = endp_desc.isoch_mult(lec); + + let max_packet_size = Self::endp_ctx_max_packet_size(endp_desc); + let max_burst_size = Self::endp_ctx_max_burst(speed_id, &dev_desc, endp_desc); + + let max_esit_payload = Self::endp_ctx_max_esit_payload( + speed_id, + &dev_desc, + endp_desc, + max_packet_size, + max_burst_size, + ); + let max_esit_payload_lo = max_esit_payload as u16; + let max_esit_payload_hi = ((max_esit_payload & 0x00FF_0000) >> 16) as u8; + + let interval = Self::endp_ctx_interval(speed_id, endp_desc); + + let max_error_count = 3; + let ep_ty = endp_desc.xhci_ep_type()?; + let host_initiate_disable = false; + + let avg_trb_len: u16 = match endp_desc.ty() { + EndpointTy::Ctrl => { + warn!("trying to use control endpoint"); + return 
Err(Error::new(EIO)); + } + EndpointTy::Bulk | EndpointTy::Isoch => 3072, + EndpointTy::Interrupt => 1024, + }; + + assert_eq!(ep_ty & 0x7, ep_ty); + assert_eq!(mult & 0x3, mult); + assert_eq!(max_error_count & 0x3, max_error_count); + assert_ne!(ep_ty, 0); + + let ring_ptr = if usb_log_max_streams.is_some() { + let mut array = + StreamContextArray::new::(self.cap.ac64(), 1 << (primary_streams + 1))?; + + array.add_ring::(self.cap.ac64(), 1, true)?; + let array_ptr = array.register(); + + assert_eq!( + array_ptr & 0xFFFF_FFFF_FFFF_FF81, + array_ptr, + "stream ctx ptr not aligned to 16 bytes" + ); + + staged_endpoint_states.push(( + endp_num, + EndpointState { + transfer: super::RingOrStreams::Streams(array), + driver_if_state: EndpIfState::Init, + }, + )); + + array_ptr + } else { + let ring = Ring::new::(self.cap.ac64(), 16, true)?; + let ring_ptr = ring.register(); + + assert_eq!( + ring_ptr & 0xFFFF_FFFF_FFFF_FF81, + ring_ptr, + "ring pointer not aligned to 16 bytes" + ); + + staged_endpoint_states.push(( + endp_num, + EndpointState { + transfer: super::RingOrStreams::Ring(ring), + driver_if_state: EndpIfState::Init, + }, + )); + + ring_ptr + }; + assert_eq!(primary_streams & 0x1F, primary_streams); + + endpoint_programs.push(EndpointProgram { + endp_num, + endp_num_xhc, + a: u32::from(mult) << 8 + | u32::from(primary_streams) << 10 + | u32::from(linear_stream_array) << 15 + | u32::from(interval) << 16 + | u32::from(max_esit_payload_hi) << 24, + b: max_error_count << 1 + | u32::from(ep_ty) << 3 + | u32::from(host_initiate_disable) << 7 + | u32::from(max_burst_size) << 8 + | u32::from(max_packet_size) << 16, + trl: ring_ptr as u32, + trh: (ring_ptr >> 32) as u32, + c: u32::from(avg_trb_len) | (u32::from(max_esit_payload_lo) << 16), + }); + + log::debug!("staged endpoint {}", endp_num); + } + } + + let (configure_snapshot, endpoint_snapshots, input_context_physical) = { let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; - let mut 
input_context = port_state.input_context.lock().unwrap(); + let mut input_context = port_state + .input_context + .lock() + .unwrap_or_else(|err| err.into_inner()); + + let configure_snapshot = ConfigureContextSnapshot { + add_context: input_context.add_context.read(), + drop_context: input_context.drop_context.read(), + control: input_context.control.read(), + slot_a: input_context.device.slot.a.read(), + slot_b: input_context.device.slot.b.read(), + slot_c: input_context.device.slot.c.read(), + }; + + let endpoint_snapshots = endpoint_programs + .iter() + .map(|program| { + let endp_i = program.endp_num_xhc as usize - 1; + ( + endp_i, + EndpointContextSnapshot::capture_values( + input_context.device.endpoints[endp_i].a.read(), + input_context.device.endpoints[endp_i].b.read(), + input_context.device.endpoints[endp_i].trl.read(), + input_context.device.endpoints[endp_i].trh.read(), + input_context.device.endpoints[endp_i].c.read(), + ), + ) + }) + .collect::>(); // Configure the slot context as well, which holds the last index of the endp descs. 
input_context.add_context.write(1); @@ -1015,25 +1442,26 @@ impl Xhci { const HUB_PORTS_MASK: u32 = 0xFF00_0000; const HUB_PORTS_SHIFT: u8 = 24; + let mut current_slot_c = input_context.device.slot.c.read(); let mut current_slot_a = input_context.device.slot.a.read(); let mut current_slot_b = input_context.device.slot.b.read(); - // Set context entries current_slot_a &= !CONTEXT_ENTRIES_MASK; current_slot_a |= (u32::from(new_context_entries) << CONTEXT_ENTRIES_SHIFT) & CONTEXT_ENTRIES_MASK; - // Set hub data current_slot_a &= !(1 << 26); current_slot_b &= !HUB_PORTS_MASK; if let Some(hub_ports) = req.hub_ports { current_slot_a |= 1 << 26; current_slot_b |= (u32::from(hub_ports) << HUB_PORTS_SHIFT) & HUB_PORTS_MASK; } + current_slot_c = apply_hub_tt_info(current_slot_c, req); input_context.device.slot.a.write(current_slot_a); input_context.device.slot.b.write(current_slot_b); + input_context.device.slot.c.write(current_slot_c); let control = if self.op.lock().unwrap().cie() { (u32::from(req.alternate_setting.unwrap_or(0)) << 16) @@ -1043,174 +1471,138 @@ impl Xhci { 0 }; input_context.control.write(control); - } - for endp_idx in 0..endp_desc_count as u8 { - let endp_num = endp_idx + 1; - - let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; - let dev_desc = port_state.dev_desc.as_ref().unwrap(); - let endp_desc = port_state.get_endp_desc(endp_idx).ok_or_else(|| { - warn!("failed to find endpoint {}", endp_idx); - Error::new(EIO) - })?; - - let endp_num_xhc = Self::endp_num_to_dci(endp_num, endp_desc); - - let usb_log_max_streams = endp_desc.log_max_streams(); - - // TODO: Secondary streams. - let primary_streams = if let Some(log_max_streams) = usb_log_max_streams { - // TODO: Can streams-capable be configured to not use streams? 
- if log_max_psa_size != 0 { - cmp::min(u8::from(log_max_streams), log_max_psa_size + 1) - 1 - } else { - 0 - } - } else { - 0 - }; - let linear_stream_array = if primary_streams != 0 { true } else { false }; + for program in &endpoint_programs { + let endp_i = program.endp_num_xhc as usize - 1; + input_context.add_context.writef(1 << program.endp_num_xhc, true); + input_context.device.endpoints[endp_i].a.write(program.a); + input_context.device.endpoints[endp_i].b.write(program.b); + input_context.device.endpoints[endp_i].trl.write(program.trl); + input_context.device.endpoints[endp_i].trh.write(program.trh); + input_context.device.endpoints[endp_i].c.write(program.c); + } - // TODO: Interval related fields - // TODO: Max ESIT payload size. + (configure_snapshot, endpoint_snapshots, input_context.physical()) + }; - let mult = endp_desc.isoch_mult(lec); + let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; + let slot = port_state.slot; - let max_packet_size = Self::endp_ctx_max_packet_size(endp_desc); - let max_burst_size = Self::endp_ctx_max_burst(speed_id, dev_desc, endp_desc); + let (event_trb, command_trb) = self + .execute_command(|trb, cycle| trb.configure_endpoint(slot, input_context_physical, cycle)) + .await; - let max_esit_payload = Self::endp_ctx_max_esit_payload( - speed_id, - dev_desc, - endp_desc, - max_packet_size, - max_burst_size, - ); - let max_esit_payload_lo = max_esit_payload as u16; - let max_esit_payload_hi = ((max_esit_payload & 0x00FF_0000) >> 16) as u8; - - let interval = Self::endp_ctx_interval(speed_id, endp_desc); - - let max_error_count = 3; - let ep_ty = endp_desc.xhci_ep_type()?; - let host_initiate_disable = false; - - // TODO: Maybe this value is out of scope for xhcid, because the actual usb device - // driver probably knows better. The spec says that the initial value should be 8 bytes - // for control, 1KiB for interrupt and 3KiB for bulk and isoch. 
- let avg_trb_len: u16 = match endp_desc.ty() { - EndpointTy::Ctrl => { - warn!("trying to use control endpoint"); - return Err(Error::new(EIO)); // only endpoint zero is of type control, and is configured separately with the address device command. + self.event_handler_finished(); + + if let Err(err) = handle_event_trb("CONFIGURE_ENDPOINT", &event_trb, &command_trb) { + let rollback_input_context_physical = match self.restore_configure_input_context( + port, + configure_snapshot, + &endpoint_snapshots, + ) { + Ok(physical) => physical, + Err(restore_err) => { + warn!( + "failed to restore configure input context after CONFIGURE_ENDPOINT failure: {:?}", + restore_err + ); + return Err(err); } - EndpointTy::Bulk | EndpointTy::Isoch => 3072, // 3 KiB - EndpointTy::Interrupt => 1024, // 1 KiB }; - assert_eq!(ep_ty & 0x7, ep_ty); - assert_eq!(mult & 0x3, mult); - assert_eq!(max_error_count & 0x3, max_error_count); - assert_ne!(ep_ty, 0); // 0 means invalid. - - let ring_ptr = if usb_log_max_streams.is_some() { - let mut array = - StreamContextArray::new::(self.cap.ac64(), 1 << (primary_streams + 1))?; + let (rollback_event_trb, rollback_command_trb) = self + .execute_command(|trb, cycle| { + trb.configure_endpoint(slot, rollback_input_context_physical, cycle) + }) + .await; - // TODO: Use as many stream rings as needed. 
- array.add_ring::(self.cap.ac64(), 1, true)?; - let array_ptr = array.register(); + self.event_handler_finished(); - assert_eq!( - array_ptr & 0xFFFF_FFFF_FFFF_FF81, - array_ptr, - "stream ctx ptr not aligned to 16 bytes" - ); - port_state.endpoint_states.insert( - endp_num, - EndpointState { - transfer: super::RingOrStreams::Streams(array), - driver_if_state: EndpIfState::Init, - }, + if let Err(rollback_err) = + handle_event_trb("CONFIGURE_ENDPOINT_ROLLBACK", &rollback_event_trb, &rollback_command_trb) + { + warn!( + "failed to roll back CONFIGURE_ENDPOINT after failure {:?}: {:?}", + err, + rollback_err ); + } - array_ptr - } else { - let ring = Ring::new::(self.cap.ac64(), 16, true)?; - let ring_ptr = ring.register(); + return Err(err); + } - assert_eq!( - ring_ptr & 0xFFFF_FFFF_FFFF_FF81, - ring_ptr, - "ring pointer not aligned to 16 bytes" - ); - port_state.endpoint_states.insert( - endp_num, - EndpointState { - transfer: super::RingOrStreams::Ring(ring), - driver_if_state: EndpIfState::Init, - }, - ); - ring_ptr - }; - assert_eq!(primary_streams & 0x1F, primary_streams); - - let mut input_context = port_state.input_context.lock().unwrap(); - input_context.add_context.writef(1 << endp_num_xhc, true); - - let endp_i = endp_num_xhc as usize - 1; - input_context.device.endpoints[endp_i].a.write( - u32::from(mult) << 8 - | u32::from(primary_streams) << 10 - | u32::from(linear_stream_array) << 15 - | u32::from(interval) << 16 - | u32::from(max_esit_payload_hi) << 24, + if self.consume_test_hook("fail_after_configure_endpoint") { + info!( + "xhcid: test hook injecting failure after CONFIGURE_ENDPOINT for port {}", + port ); - input_context.device.endpoints[endp_i].b.write( - max_error_count << 1 - | u32::from(ep_ty) << 3 - | u32::from(host_initiate_disable) << 7 - | u32::from(max_burst_size) << 8 - | u32::from(max_packet_size) << 16, - ); - - input_context.device.endpoints[endp_i] - .trl - .write(ring_ptr as u32); - input_context.device.endpoints[endp_i] - .trh - 
.write((ring_ptr >> 32) as u32); - - input_context.device.endpoints[endp_i] - .c - .write(u32::from(avg_trb_len) | (u32::from(max_esit_payload_lo) << 16)); - - log::debug!("initialized endpoint {}", endp_num); + self.rollback_configure_attempt( + port, + slot, + configure_snapshot, + &endpoint_snapshots, + "test hook fail_after_configure_endpoint", + ) + .await; + return Err(Error::new(EIO)); } - { - let port_state = self.port_states.get(&port).ok_or(Error::new(EBADFD))?; - let slot = port_state.slot; - let input_context_physical = port_state.input_context.lock().unwrap().physical(); + // Tell the device about this configuration. + let skip_set_configuration = self + .port_states + .get(&port) + .map(|port_state| { + port_state + .quirks + .contains(crate::usb_quirks::UsbQuirkFlags::NO_SET_CONFIG) + }) + .unwrap_or(false); - let (event_trb, command_trb) = self - .execute_command(|trb, cycle| { - trb.configure_endpoint(slot, input_context_physical, cycle) - }) + if !skip_set_configuration { + if let Err(err) = self.set_configuration(port, configuration_value).await { + self.rollback_configure_attempt( + port, + slot, + configure_snapshot, + &endpoint_snapshots, + "set_configuration failure", + ) .await; - //self.event_handler_finished(); + return Err(err); + } - handle_event_trb("CONFIGURE_ENDPOINT", &event_trb, &command_trb)?; + if self.consume_test_hook("fail_after_set_configuration") { + info!( + "xhcid: test hook injecting failure after SET_CONFIGURATION for port {}", + port + ); + self.rollback_configure_attempt( + port, + slot, + configure_snapshot, + &endpoint_snapshots, + "test hook fail_after_set_configuration", + ) + .await; + return Err(Error::new(EIO)); + } } - // Tell the device about this configuration. 
- self.set_configuration(port, configuration_value).await?; + { + let mut port_state = self.port_states.get_mut(&port).ok_or(Error::new(EBADFD))?; + port_state.cfg_idx = Some(configuration_value); + port_state.endpoint_states.retain(|endp_num, _| *endp_num == 0); + for (endp_num, endpoint_state) in staged_endpoint_states { + port_state.endpoint_states.insert(endp_num, endpoint_state); + } + } Ok(()) } async fn configure_endpoints(&self, port: PortId, json_buf: &[u8]) -> Result<()> { + let _op = self.begin_routable_operation(port)?; let mut req: ConfigureEndpointsReq = serde_json::from_slice(json_buf).or(Err(Error::new(EBADMSG)))?; @@ -1234,8 +1626,20 @@ impl Xhci { if let Some(interface_num) = req.interface_desc { if let Some(alternate_setting) = req.alternate_setting { - self.set_interface(port, interface_num, alternate_setting) - .await?; + let skip_set_interface = self + .port_states + .get(&port) + .map(|port_state| { + port_state + .quirks + .contains(crate::usb_quirks::UsbQuirkFlags::NO_SET_INTF) + }) + .unwrap_or(false); + + if !skip_set_interface { + self.set_interface(port, interface_num, alternate_setting) + .await?; + } } } @@ -1432,7 +1836,7 @@ impl Xhci { }, ) .await?; - //self.event_handler_finished(); + self.event_handler_finished(); let bytes_transferred = dma_buf .as_ref() @@ -1453,52 +1857,109 @@ impl Xhci { let raw_dd = self.fetch_dev_desc(port_id, slot).await?; log::debug!("port {} slot {} desc {:X?}", port_id, slot, raw_dd); + let vendor = raw_dd.vendor; + let product = raw_dd.product; + let quirks = crate::usb_quirks::lookup_usb_quirks(vendor, product); + if !quirks.is_empty() { + log::info!( + "port {}: USB quirks for {:04x}:{:04x}: {:?}", + port_id, vendor, product, quirks + ); + } + // Only fetch language IDs if we need to. Some devices will fail to return this descriptor //TODO: also check configurations and interfaces for defined strings? 
+ let bad_descriptor = quirks.contains(crate::usb_quirks::UsbQuirkFlags::BAD_DESCRIPTOR); + let lang_id = - if raw_dd.manufacturer_str > 0 || raw_dd.product_str > 0 || raw_dd.serial_str > 0 { - let lang_ids = self.fetch_lang_ids_desc(port_id, slot).await?; - // Prefer US English, but fall back to first language ID, or zero - let en_us_id = 0x409; - if lang_ids.contains(&en_us_id) { - en_us_id - } else { - match lang_ids.first() { - Some(some) => *some, - None => 0, + if !quirks.contains(crate::usb_quirks::UsbQuirkFlags::NO_STRING_FETCH) + && (raw_dd.manufacturer_str > 0 + || raw_dd.product_str > 0 + || raw_dd.serial_str > 0) + { + match self.fetch_lang_ids_desc(port_id, slot).await { + Ok(lang_ids) => { + // Prefer US English, but fall back to first language ID, or zero + let en_us_id = 0x409; + if lang_ids.contains(&en_us_id) { + en_us_id + } else { + match lang_ids.first() { + Some(some) => *some, + None => 0, + } + } + } + Err(err) if bad_descriptor => { + log::warn!( + "port {} slot {}: failed to fetch language IDs with BAD_DESCRIPTOR set: {}", + port_id, + slot, + err + ); + 0 } + Err(err) => return Err(err), } } else { 0 }; log::debug!("port {} using language ID 0x{:04x}", port_id, lang_id); - let (manufacturer_str, product_str, serial_str) = ( - if raw_dd.manufacturer_str > 0 { - Some( - self.fetch_string_desc(port_id, slot, raw_dd.manufacturer_str, lang_id) - .await?, - ) - } else { - None - }, - if raw_dd.product_str > 0 { - Some( - self.fetch_string_desc(port_id, slot, raw_dd.product_str, lang_id) - .await?, - ) + let (manufacturer_str, product_str, serial_str) = + if quirks.contains(crate::usb_quirks::UsbQuirkFlags::NO_STRING_FETCH) { + (None, None, None) } else { - None - }, - if raw_dd.serial_str > 0 { - Some( - self.fetch_string_desc(port_id, slot, raw_dd.serial_str, lang_id) - .await?, + ( + if raw_dd.manufacturer_str > 0 { + if bad_descriptor { + self.fetch_string_desc(port_id, slot, raw_dd.manufacturer_str, lang_id) + .await + .ok() + } else { + 
Some( + self.fetch_string_desc( + port_id, + slot, + raw_dd.manufacturer_str, + lang_id, + ) + .await?, + ) + } + } else { + None + }, + if raw_dd.product_str > 0 { + if bad_descriptor { + self.fetch_string_desc(port_id, slot, raw_dd.product_str, lang_id) + .await + .ok() + } else { + Some( + self.fetch_string_desc(port_id, slot, raw_dd.product_str, lang_id) + .await?, + ) + } + } else { + None + }, + if raw_dd.serial_str > 0 { + if bad_descriptor { + self.fetch_string_desc(port_id, slot, raw_dd.serial_str, lang_id) + .await + .ok() + } else { + Some( + self.fetch_string_desc(port_id, slot, raw_dd.serial_str, lang_id) + .await?, + ) + } + } else { + None + }, ) - } else { - None - }, - ); + }; log::debug!( "manufacturer {:?} product {:?} serial {:?}", manufacturer_str, @@ -1508,14 +1969,39 @@ impl Xhci { //TODO let (bos_desc, bos_data) = self.fetch_bos_desc(port_id, slot).await?; - let supports_superspeed = false; - //TODO usb::bos_capability_descs(bos_desc, &bos_data).any(|desc| desc.is_superspeed()); - let supports_superspeedplus = false; - //TODO usb::bos_capability_descs(bos_desc, &bos_data).any(|desc| desc.is_superspeedplus()); + let (supports_superspeed, supports_superspeedplus) = + if quirks.contains(crate::usb_quirks::UsbQuirkFlags::NO_BOS) { + (false, false) + } else { + match self.fetch_bos_desc(port_id, slot).await { + Ok((bos_desc, bos_data)) => ( + usb::bos_capability_descs(bos_desc, &bos_data) + .any(|desc| desc.is_superspeed()), + usb::bos_capability_descs(bos_desc, &bos_data) + .any(|desc| desc.is_superspeedplus()), + ), + Err(err) => { + log::debug!( + "port {} slot {}: failed to fetch BOS descriptor: {}", + port_id, + slot, + err + ); + (false, false) + } + } + }; let mut config_descs = SmallVec::new(); - for index in 0..raw_dd.configurations { + let configuration_indices: Vec = + if quirks.contains(crate::usb_quirks::UsbQuirkFlags::FORCE_ONE_CONFIG) { + vec![0] + } else { + (0..raw_dd.configurations).collect() + }; + + for index in 
configuration_indices { debug!("Fetching the config descriptor at index {}", index); let (desc, data) = self.fetch_config_desc(port_id, slot, index).await?; log::debug!( @@ -1541,6 +2027,12 @@ impl Xhci { let mut iter = descriptors.into_iter().peekable(); while let Some(item) = iter.next() { + if quirks.contains(crate::usb_quirks::UsbQuirkFlags::HONOR_BNUMINTERFACES) + && interface_descs.len() >= desc.interfaces as usize + { + break; + } + if let AnyDescriptor::Interface(idesc) = item { let mut endpoints = SmallVec::<[EndpDesc; 4]>::new(); let mut hid_descs = SmallVec::<[HidDesc; 1]>::new(); @@ -1554,6 +2046,9 @@ impl Xhci { } Some(unexpected) => { log::warn!("expected endpoint, got {:X?}", unexpected); + if bad_descriptor { + continue; + } break; } None => break, @@ -1578,8 +2073,16 @@ impl Xhci { } interface_descs.push( - self.new_if_desc(port_id, slot, idesc, endpoints, hid_descs, lang_id) - .await?, + self.new_if_desc( + port_id, + slot, + idesc, + endpoints, + hid_descs, + lang_id, + quirks, + ) + .await?, ); } else { log::warn!("expected interface, got {:?}", item); @@ -1590,11 +2093,20 @@ impl Xhci { config_descs.push(ConfDesc { kind: desc.kind, - configuration: if desc.configuration_str > 0 { - Some( + configuration: if quirks.contains(crate::usb_quirks::UsbQuirkFlags::NO_STRING_FETCH) + { + None + } else if desc.configuration_str > 0 { + if bad_descriptor { self.fetch_string_desc(port_id, slot, desc.configuration_str, lang_id) - .await?, - ) + .await + .ok() + } else { + Some( + self.fetch_string_desc(port_id, slot, desc.configuration_str, lang_id) + .await?, + ) + } } else { None }, @@ -1856,7 +2368,7 @@ impl Xhci { if (flags & O_DIRECTORY != 0) || (flags & O_STAT != 0) { let mut contents = Vec::new(); - write!(contents, "descriptors\nendpoints\n").unwrap(); + write!(contents, "descriptors\nendpoints\npm_state\nsuspend\nresume\n").unwrap(); if self.slot_state( self.port_states @@ -1893,6 +2405,14 @@ impl Xhci { Ok(Handle::PortState(port_num)) } + fn 
open_handle_port_pm_state(&self, port_num: PortId, flags: usize) -> Result { + if flags & O_DIRECTORY != 0 && flags & O_STAT == 0 { + return Err(Error::new(ENOTDIR)); + } + + Ok(Handle::PortPmState(port_num)) + } + /// implements open() for /port/endpoints /// /// # Arguments @@ -2087,6 +2607,30 @@ impl Xhci { Ok(Handle::DetachDevice(port_num)) } + fn open_handle_suspend_device(&self, port_num: PortId, flags: usize) -> Result { + if flags & O_DIRECTORY != 0 && flags & O_STAT == 0 { + return Err(Error::new(ENOTDIR)); + } + + if flags & O_RDWR != O_WRONLY && flags & O_STAT == 0 { + return Err(Error::new(EACCES)); + } + + Ok(Handle::SuspendDevice(port_num)) + } + + fn open_handle_resume_device(&self, port_num: PortId, flags: usize) -> Result { + if flags & O_DIRECTORY != 0 && flags & O_STAT == 0 { + return Err(Error::new(ENOTDIR)); + } + + if flags & O_RDWR != O_WRONLY && flags & O_STAT == 0 { + return Err(Error::new(EACCES)); + } + + Ok(Handle::ResumeDevice(port_num)) + } + /// implements open() for /port/request /// /// # Arguments @@ -2155,6 +2699,9 @@ impl SchemeSync for &Xhci { SchemeParameters::PortState(port_number) => { self.open_handle_port_state(port_number, flags)? } + SchemeParameters::PortPmState(port_number) => { + self.open_handle_port_pm_state(port_number, flags)? + } SchemeParameters::PortReq(port_number) => { self.open_handle_port_request(port_number, flags)? } @@ -2173,6 +2720,12 @@ impl SchemeSync for &Xhci { SchemeParameters::DetachDevice(port_number) => { self.open_handle_detach_device(port_number, flags)? } + SchemeParameters::SuspendDevice(port_number) => { + self.open_handle_suspend_device(port_number, flags)? + } + SchemeParameters::ResumeDevice(port_number) => { + self.open_handle_resume_device(port_number, flags)? + } }; let fd = self.next_handle.fetch_add(1, atomic::Ordering::Relaxed); @@ -2203,7 +2756,11 @@ impl SchemeSync for &Xhci { //If we have a handle to the configure scheme, we need to mark it as write only. 
match &*guard { - Handle::ConfigureEndpoints(_) | Handle::AttachDevice(_) | Handle::DetachDevice(_) => { + Handle::ConfigureEndpoints(_) + | Handle::AttachDevice(_) + | Handle::DetachDevice(_) + | Handle::SuspendDevice(_) + | Handle::ResumeDevice(_) => { stat.st_mode = stat.st_mode | 0o200; } _ => {} @@ -2263,6 +2820,8 @@ impl SchemeSync for &Xhci { Handle::ConfigureEndpoints(_) => Err(Error::new(EBADF)), Handle::AttachDevice(_) => Err(Error::new(EBADF)), Handle::DetachDevice(_) => Err(Error::new(EBADF)), + Handle::SuspendDevice(_) => Err(Error::new(EBADF)), + Handle::ResumeDevice(_) => Err(Error::new(EBADF)), Handle::SchemeRoot => Err(Error::new(EBADF)), &mut Handle::Endpoint(port_num, endp_num, ref mut st) => match st { @@ -2294,6 +2853,10 @@ impl SchemeSync for &Xhci { Ok(Xhci::::write_dyn_string(string, buf, offset)) } + &mut Handle::PortPmState(port_num) => { + let ps = self.port_states.get(&port_num).ok_or(Error::new(EBADF))?; + Ok(Xhci::::write_dyn_string(ps.pm_state.as_str().as_bytes(), buf, offset)) + } &mut Handle::PortReq(port_num, ref mut st) => { let state = std::mem::replace(st, PortReqState::Tmp); drop(guard); // release the lock @@ -2333,6 +2896,14 @@ impl SchemeSync for &Xhci { block_on(self.detach_device(port_num))?; Ok(buf.len()) } + &mut Handle::SuspendDevice(port_num) => { + block_on(self.suspend_device(port_num))?; + Ok(buf.len()) + } + &mut Handle::ResumeDevice(port_num) => { + block_on(self.resume_device(port_num))?; + Ok(buf.len()) + } &mut Handle::Endpoint(port_num, endp_num, ref ep_file_ty) => match ep_file_ty { EndpointHandleTy::Ctl => block_on(self.on_write_endp_ctl(port_num, endp_num, buf)), EndpointHandleTy::Data => { @@ -2356,6 +2927,59 @@ impl Xhci { self.handles.remove(&fd); } + fn ensure_port_active(&self, port_num: PortId) -> Result<()> { + let port_state = self.port_states.get(&port_num).ok_or(Error::new(EBADFD))?; + if port_state.lifecycle.state() == super::PortLifecycleState::Detaching { + return Err(Error::new(EBUSY)); + } + 
+ let pm_state = port_state.pm_state; + match pm_state { + super::PortPmState::Active => Ok(()), + super::PortPmState::Suspended => Err(Error::new(EBUSY)), + } + } + + pub async fn suspend_device(&self, port_num: PortId) -> Result<()> { + let _op = self.begin_attached_operation(port_num)?; + let mut port_state = self.port_states.get_mut(&port_num).ok_or(Error::new(EBADFD))?; + + if port_state + .quirks + .contains(crate::usb_quirks::UsbQuirkFlags::NO_SUSPEND) + { + return Err(Error::new(EOPNOTSUPP)); + } + + if port_state.pm_state != super::PortPmState::Active { + return Err(Error::new(EBUSY)); + } + + port_state.pm_state = super::PortPmState::Suspended; + Ok(()) + } + + pub async fn resume_device(&self, port_num: PortId) -> Result<()> { + let _op = self.begin_attached_operation(port_num)?; + let mut port_state = self.port_states.get_mut(&port_num).ok_or(Error::new(EBADFD))?; + + if port_state.pm_state == super::PortPmState::Active { + return Ok(()); + } + + let slot_state = self.slot_state(port_state.slot as usize); + if slot_state != SlotState::Addressed as u8 && slot_state != SlotState::Configured as u8 { + warn!( + "refusing to resume port {} while slot {} is in controller state {}", + port_num, port_state.slot, slot_state + ); + return Err(Error::new(EIO)); + } + + port_state.pm_state = super::PortPmState::Active; + Ok(()) + } + pub fn get_endp_status(&self, port_num: PortId, endp_num: u8) -> Result { let port_state = self.port_states.get(&port_num).ok_or(Error::new(EBADFD))?; @@ -2406,6 +3030,8 @@ impl Xhci { endp_num: u8, clear_feature: bool, ) -> Result<()> { + self.ensure_port_active(port_num)?; + if self.get_endp_status(port_num, endp_num)? 
!= EndpointStatus::Halted { return Err(Error::new(EPROTO)); } @@ -2531,7 +3157,7 @@ impl Xhci { ) }) .await; - //self.event_handler_finished(); + self.event_handler_finished(); handle_event_trb("SET_TR_DEQUEUE_PTR", &event_trb, &command_trb) } @@ -2541,10 +3167,14 @@ impl Xhci { endp_num: u8, buf: &[u8], ) -> Result { + let _op = self.begin_routable_operation(port_num)?; let mut port_state = self .port_states .get_mut(&port_num) .ok_or(Error::new(EBADF))?; + if port_state.pm_state != super::PortPmState::Active { + return Err(Error::new(EBUSY)); + } let ep_if_state = &mut port_state .endpoint_states @@ -2562,6 +3192,7 @@ impl Xhci { }, XhciEndpCtlReq::Reset { no_clear_feature } => match ep_if_state { EndpIfState::Init => { + drop(port_state); self.on_req_reset_device(port_num, endp_num, !no_clear_feature) .await? } @@ -2631,6 +3262,9 @@ impl Xhci { endp_num: u8, buf: &[u8], ) -> Result { + let _op = self.begin_routable_operation(port_num)?; + self.ensure_port_active(port_num)?; + let mut port_state = self .port_states .get_mut(&port_num) @@ -2732,6 +3366,9 @@ impl Xhci { endp_num: u8, buf: &mut [u8], ) -> Result { + let _op = self.begin_routable_operation(port_num)?; + self.ensure_port_active(port_num)?; + let mut port_state = self .port_states .get_mut(&port_num) @@ -2832,6 +3469,64 @@ pub fn handle_transfer_event_trb(name: &str, event_trb: &Trb, transfer_trb: &Trb Err(Error::new(EIO)) } } + +fn apply_hub_tt_info(current_slot_c: u32, req: &ConfigureEndpointsReq) -> u32 { + const TT_THINK_TIME_MASK: u32 = 0x0003_0000; + const TT_THINK_TIME_SHIFT: u8 = 16; + + let mut slot_c = current_slot_c & !TT_THINK_TIME_MASK; + if req.hub_ports.is_some() { + if let Some(hub_think_time) = req.hub_think_time { + slot_c |= (u32::from(hub_think_time) << TT_THINK_TIME_SHIFT) & TT_THINK_TIME_MASK; + } + } + slot_c +} + +#[derive(Clone, Copy)] +struct ConfigureContextSnapshot { + add_context: u32, + drop_context: u32, + control: u32, + slot_a: u32, + slot_b: u32, + slot_c: u32, +} + 
+#[derive(Clone, Copy)] +struct EndpointContextSnapshot { + a: u32, + b: u32, + trl: u32, + trh: u32, + c: u32, +} + +impl EndpointContextSnapshot { + fn capture_values(a: u32, b: u32, trl: u32, trh: u32, c: u32) -> Self { + Self { a, b, trl, trh, c } + } + + fn restore(&self, ctx: &mut EndpointContext) { + ctx.a.write(self.a); + ctx.b.write(self.b); + ctx.trl.write(self.trl); + ctx.trh.write(self.trh); + ctx.c.write(self.c); + } +} + +#[derive(Clone, Copy)] +struct EndpointProgram { + endp_num: u8, + endp_num_xhc: u8, + a: u32, + b: u32, + trl: u32, + trh: u32, + c: u32, +} + use lazy_static::lazy_static; use std::ops::{Add, Div, Rem}; @@ -2845,3 +3540,26 @@ where a / b } } + +#[cfg(test)] +mod tests { + use super::{apply_hub_tt_info, ConfigureEndpointsReq}; + + #[test] + fn apply_hub_tt_info_only_sets_bits_for_hub_requests() { + let req = ConfigureEndpointsReq { + config_desc: 1, + interface_desc: None, + alternate_setting: None, + hub_ports: Some(4), + hub_think_time: Some(3), + }; + assert_eq!(apply_hub_tt_info(0, &req), 0x0003_0000); + + let no_hub = ConfigureEndpointsReq { + hub_ports: None, + ..req.clone() + }; + assert_eq!(apply_hub_tt_info(0x0003_0000, &no_hub), 0); + } +} diff --git a/init.initfs.d/40_ps2d.service b/init.initfs.d/40_ps2d.service index 881e75ea..bbee2699 100644 --- a/init.initfs.d/40_ps2d.service +++ b/init.initfs.d/40_ps2d.service @@ -5,4 +5,4 @@ condition_architecture = ["x86", "x86_64"] [service] cmd = "ps2d" -type = "notify" +type = "oneshot_async" diff --git a/init/src/scheduler.rs b/init/src/scheduler.rs index d42a4e57..f8ac5cd3 100644 --- a/init/src/scheduler.rs +++ b/init/src/scheduler.rs @@ -1,7 +1,8 @@ use std::collections::VecDeque; -use crate::InitConfig; +use crate::service::ServiceType; use crate::unit::{Unit, UnitId, UnitKind, UnitStore}; +use crate::InitConfig; pub struct Scheduler { pending: VecDeque, @@ -92,22 +93,31 @@ fn run(unit: &mut Unit, config: &mut InitConfig) { } UnitKind::Service { service } => { if 
config.skip_cmd.contains(&service.cmd) { - eprintln!("Skipping '{} {}'", service.cmd, service.args.join(" ")); + eprintln!("init: SKIP {} {}", service.cmd, service.args.join(" ")); return; } - if config.log_debug { - eprintln!( - "Starting {} ({})", - unit.info.description.as_ref().unwrap_or(&unit.id.0), - service.cmd, - ); + + // Always log blocking service types (notify, scheme, oneshot) + // since these can hang the boot if the child fails to signal. + // OneshotAsync services are fire-and-forget, only log at debug. + let is_blocking = !matches!(service.type_, ServiceType::OneshotAsync); + + if is_blocking || config.log_debug { + let desc = unit.info.description.as_ref().map(|s| s.as_str()).unwrap_or("-"); + eprintln!("init: START {desc} | {} {}", service.cmd, service.args.join(" ")); } + service.spawn(&config.envs); + + if is_blocking || config.log_debug { + let desc = unit.info.description.as_ref().map(|s| s.as_str()).unwrap_or("-"); + eprintln!("init: DONE {desc} | {} {}", service.cmd, service.args.join(" ")); + } } UnitKind::Target {} => { if config.log_debug { eprintln!( - "Reached target {}", + "init: TARGET {}", unit.info.description.as_ref().unwrap_or(&unit.id.0), ); } diff --git a/init/src/script.rs b/init/src/script.rs index d18e3a04..40bcf9a4 100644 --- a/init/src/script.rs +++ b/init/src/script.rs @@ -1,8 +1,8 @@ use std::collections::BTreeMap; use std::{env, io, iter}; -use crate::InitConfig; use crate::unit::UnitId; +use crate::InitConfig; pub fn subst_env<'a>(arg: &str) -> String { if arg.starts_with('$') {