fix: add driver-core removal lifecycle

This commit is contained in:
2026-05-09 01:33:08 +01:00
parent c11c7a04df
commit ff99ad455b
2 changed files with 208 additions and 10 deletions
@@ -5,7 +5,7 @@ use alloc::vec::Vec;
use crate::bus::{Bus, BusError};
use crate::device::{BoundDevice, DeviceId, DeviceInfo};
use crate::driver::{Driver, ProbeResult};
use crate::driver::{Driver, DriverError, ProbeResult};
/// Event emitted by the device manager during discovery or deferred-probe processing.
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -203,6 +203,27 @@ impl DeviceManager {
events
}
/// Detaches a previously bound device and invokes the owning driver's cleanup path.
pub fn remove_device(&mut self, device: &DeviceId) -> Result<Option<String>, DriverError> {
let Some(bound) = self.bound_devices.get(device).cloned() else {
return Ok(None);
};
let Some(driver) = self
.drivers
.iter()
.find(|driver| driver.name() == bound.driver_name)
else {
return Err(DriverError::Other("bound driver missing"));
};
driver.remove(&bound.info)?;
self.bound_devices.remove(device);
self.deferred_queue
.retain(|(info, _driver_name)| info.id != *device);
Ok(Some(bound.driver_name))
}
fn probe_device(&mut self, info: DeviceInfo, events: &mut Vec<ProbeEvent>) {
let mut matched = false;
@@ -276,7 +297,10 @@ impl DeviceManager {
}
fn enqueue_deferred(&mut self, info: DeviceInfo, driver_name: String) {
let already_queued = self.deferred_queue.iter().any(|(queued_info, queued_driver)| {
let already_queued = self
.deferred_queue
.iter()
.any(|(queued_info, queued_driver)| {
queued_info.id == info.id && queued_driver == &driver_name
});
@@ -319,6 +343,8 @@ mod tests {
description: &'static str,
priority: i32,
matches: Vec<DriverMatch>,
probe_result: ProbeResult,
remove_count: Option<&'static std::sync::atomic::AtomicUsize>,
}
impl Driver for MockDriver {
@@ -339,10 +365,13 @@ mod tests {
}
fn probe(&self, _info: &DeviceInfo) -> ProbeResult {
ProbeResult::NotSupported
self.probe_result.clone()
}
fn remove(&self, _info: &DeviceInfo) -> Result<(), DriverError> {
if let Some(remove_count) = self.remove_count {
remove_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
Ok(())
}
}
@@ -376,6 +405,8 @@ mod tests {
subsystem_vendor: None,
subsystem_device: None,
}],
probe_result: ProbeResult::NotSupported,
remove_count: None,
}));
manager.register_driver(Box::new(MockDriver {
name: "high",
@@ -390,6 +421,8 @@ mod tests {
subsystem_vendor: None,
subsystem_device: None,
}],
probe_result: ProbeResult::NotSupported,
remove_count: None,
}));
assert_eq!(manager.buses.len(), 1);
@@ -430,4 +463,70 @@ mod tests {
if bus == "pci" && *device_count == 1
)));
}
#[test]
fn remove_device_invokes_driver_and_allows_reprobe() {
static REMOVE_COUNT: std::sync::atomic::AtomicUsize =
std::sync::atomic::AtomicUsize::new(0);
REMOVE_COUNT.store(0, std::sync::atomic::Ordering::SeqCst);
let info = DeviceInfo {
id: DeviceId {
bus: String::from("pci"),
path: String::from("0000:00:14.0"),
},
vendor: Some(0x8086),
device: Some(0x7ec0),
class: Some(0x0c),
subclass: Some(0x03),
prog_if: Some(0x30),
revision: None,
subsystem_vendor: None,
subsystem_device: None,
raw_path: String::from("/scheme/pci/0000--00--14.0"),
description: Some(String::from("xHCI controller")),
};
let mut manager = DeviceManager::new(config());
manager.register_bus(Box::new(MockBus {
name: "pci",
devices: vec![info.clone()],
}));
manager.register_driver(Box::new(MockDriver {
name: "xhcid",
description: "USB host controller",
priority: 80,
matches: vec![DriverMatch {
vendor: Some(0x8086),
device: None,
class: Some(0x0c),
subclass: Some(0x03),
prog_if: Some(0x30),
subsystem_vendor: None,
subsystem_device: None,
}],
probe_result: ProbeResult::Bound,
remove_count: Some(&REMOVE_COUNT),
}));
let first_events = manager.enumerate();
assert!(first_events.iter().any(|event| matches!(
event,
super::ProbeEvent::ProbeCompleted { result, .. } if *result == ProbeResult::Bound
)));
let removed = manager.remove_device(&info.id).unwrap();
assert_eq!(removed.as_deref(), Some("xhcid"));
assert_eq!(REMOVE_COUNT.load(std::sync::atomic::Ordering::SeqCst), 1);
let second_events = manager.enumerate();
assert!(second_events.iter().any(|event| matches!(
event,
super::ProbeEvent::ProbeCompleted { result, .. } if *result == ProbeResult::Bound
)));
assert!(
!second_events
.iter()
.any(|event| matches!(event, super::ProbeEvent::AlreadyBound { .. }))
);
}
}
@@ -5,7 +5,7 @@ use alloc::vec::Vec;
use crate::bus::{Bus, BusError};
use crate::device::{BoundDevice, DeviceId, DeviceInfo};
use crate::driver::{Driver, ProbeResult};
use crate::driver::{Driver, DriverError, ProbeResult};
/// Event emitted by the device manager during discovery or deferred-probe processing.
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -203,6 +203,27 @@ impl DeviceManager {
events
}
/// Detaches a previously bound device and invokes the owning driver's cleanup path.
pub fn remove_device(&mut self, device: &DeviceId) -> Result<Option<String>, DriverError> {
let Some(bound) = self.bound_devices.get(device).cloned() else {
return Ok(None);
};
let Some(driver) = self
.drivers
.iter()
.find(|driver| driver.name() == bound.driver_name)
else {
return Err(DriverError::Other("bound driver missing"));
};
driver.remove(&bound.info)?;
self.bound_devices.remove(device);
self.deferred_queue
.retain(|(info, _driver_name)| info.id != *device);
Ok(Some(bound.driver_name))
}
fn probe_device(&mut self, info: DeviceInfo, events: &mut Vec<ProbeEvent>) {
let mut matched = false;
@@ -276,7 +297,10 @@ impl DeviceManager {
}
fn enqueue_deferred(&mut self, info: DeviceInfo, driver_name: String) {
let already_queued = self.deferred_queue.iter().any(|(queued_info, queued_driver)| {
let already_queued = self
.deferred_queue
.iter()
.any(|(queued_info, queued_driver)| {
queued_info.id == info.id && queued_driver == &driver_name
});
@@ -319,6 +343,8 @@ mod tests {
description: &'static str,
priority: i32,
matches: Vec<DriverMatch>,
probe_result: ProbeResult,
remove_count: Option<&'static std::sync::atomic::AtomicUsize>,
}
impl Driver for MockDriver {
@@ -339,10 +365,13 @@ mod tests {
}
fn probe(&self, _info: &DeviceInfo) -> ProbeResult {
ProbeResult::NotSupported
self.probe_result.clone()
}
fn remove(&self, _info: &DeviceInfo) -> Result<(), DriverError> {
if let Some(remove_count) = self.remove_count {
remove_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
Ok(())
}
}
@@ -376,6 +405,8 @@ mod tests {
subsystem_vendor: None,
subsystem_device: None,
}],
probe_result: ProbeResult::NotSupported,
remove_count: None,
}));
manager.register_driver(Box::new(MockDriver {
name: "high",
@@ -390,6 +421,8 @@ mod tests {
subsystem_vendor: None,
subsystem_device: None,
}],
probe_result: ProbeResult::NotSupported,
remove_count: None,
}));
assert_eq!(manager.buses.len(), 1);
@@ -430,4 +463,70 @@ mod tests {
if bus == "pci" && *device_count == 1
)));
}
#[test]
fn remove_device_invokes_driver_and_allows_reprobe() {
static REMOVE_COUNT: std::sync::atomic::AtomicUsize =
std::sync::atomic::AtomicUsize::new(0);
REMOVE_COUNT.store(0, std::sync::atomic::Ordering::SeqCst);
let info = DeviceInfo {
id: DeviceId {
bus: String::from("pci"),
path: String::from("0000:00:14.0"),
},
vendor: Some(0x8086),
device: Some(0x7ec0),
class: Some(0x0c),
subclass: Some(0x03),
prog_if: Some(0x30),
revision: None,
subsystem_vendor: None,
subsystem_device: None,
raw_path: String::from("/scheme/pci/0000--00--14.0"),
description: Some(String::from("xHCI controller")),
};
let mut manager = DeviceManager::new(config());
manager.register_bus(Box::new(MockBus {
name: "pci",
devices: vec![info.clone()],
}));
manager.register_driver(Box::new(MockDriver {
name: "xhcid",
description: "USB host controller",
priority: 80,
matches: vec![DriverMatch {
vendor: Some(0x8086),
device: None,
class: Some(0x0c),
subclass: Some(0x03),
prog_if: Some(0x30),
subsystem_vendor: None,
subsystem_device: None,
}],
probe_result: ProbeResult::Bound,
remove_count: Some(&REMOVE_COUNT),
}));
let first_events = manager.enumerate();
assert!(first_events.iter().any(|event| matches!(
event,
super::ProbeEvent::ProbeCompleted { result, .. } if *result == ProbeResult::Bound
)));
let removed = manager.remove_device(&info.id).unwrap();
assert_eq!(removed.as_deref(), Some("xhcid"));
assert_eq!(REMOVE_COUNT.load(std::sync::atomic::Ordering::SeqCst), 1);
let second_events = manager.enumerate();
assert!(second_events.iter().any(|event| matches!(
event,
super::ProbeEvent::ProbeCompleted { result, .. } if *result == ProbeResult::Bound
)));
assert!(
!second_events
.iter()
.any(|event| matches!(event, super::ProbeEvent::AlreadyBound { .. }))
);
}
}