diff --git a/kernel/src/greq/services.rs b/kernel/src/greq/services.rs index a1c213cc4..6e91b2948 100644 --- a/kernel/src/greq/services.rs +++ b/kernel/src/greq/services.rs @@ -121,7 +121,9 @@ mod tests { use alloc::vec; - let sp = svsm_test_io(IORequest::GetLaunchMeasurement); + let sp = svsm_test_io().unwrap(); + + sp.put_byte(IORequest::GetLaunchMeasurement as u8); let mut expected_measurement = [0u8; 48]; for byte in &mut expected_measurement { diff --git a/kernel/src/testing.rs b/kernel/src/testing.rs index 23554e185..8650e8d16 100644 --- a/kernel/src/testing.rs +++ b/kernel/src/testing.rs @@ -4,13 +4,11 @@ use test::ShouldPanic; use crate::{ cpu::percpu::current_ghcb, locking::{LockGuard, SpinLock}, - serial::{SerialPort, Terminal}, + platform::SVSM_PLATFORM, + serial::SerialPort, sev::ghcb::GHCBIOSize, - svsm_console::SVSMIOPort, }; -use core::sync::atomic::{AtomicBool, Ordering}; - #[macro_export] macro_rules! assert_eq_warn { ($left:expr, $right:expr) => { @@ -30,10 +28,7 @@ macro_rules! assert_eq_warn { } pub use assert_eq_warn; -static SERIAL_INITIALIZED: AtomicBool = AtomicBool::new(false); -static IOPORT: SVSMIOPort = SVSMIOPort::new(); -static SERIAL_PORT: SpinLock> = - SpinLock::new(SerialPort::new(&IOPORT, 0x2e8 /*COM4*/)); +static SERIAL_PORT: SpinLock>> = SpinLock::new(None); /// Byte used to tell the host the request we need for the test. /// These values must be aligned with `test_io()` in scripts/test-in-svsm.sh @@ -49,17 +44,15 @@ pub enum IORequest { /// used in a test. The request (first byte) is sent by this function, so the /// caller can start using the serial port according to the request implemented /// in `test_io()` in scripts/test-in-svsm.sh -pub fn svsm_test_io(req: IORequest) -> LockGuard<'static, SerialPort<'static>> { - let sp = SERIAL_PORT.lock(); - if SERIAL_INITIALIZED - .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) - .is_ok() - { - sp.init(); +pub fn svsm_test_io() -> LockGuard<'static, Option>> { + let mut sp = SERIAL_PORT.lock(); + if sp.is_none() { + let io_port = SVSM_PLATFORM.as_dyn_ref().get_io_port(); + let serial_port = SerialPort::new(io_port, 0x2e8 /*COM4*/); + *sp = Some(serial_port); + serial_port.init(); } - sp.put_byte(req as u8); - sp }