diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8526faac..ed814eed 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: sudo apt update sudo apt-get install -y build-essential octave portaudio19-dev python-dev-is-python3 export MAKEFLAGS="-j $(grep -c ^processor /proc/cpuinfo)" - pip install -e .[develop,test,documentation,quadriga,uhd,audio] + pip install -e .[develop,test,quadriga,uhd,audio,sionna,scapy] - name: Run unit tests run: | @@ -73,6 +73,7 @@ jobs: with: python-version: '3.11' + # Note: Sionna dependencies crash with hermespy[documentation] due to outdated ipywidgets requirement on Sionna's side - name: Install doc dependencies run: | sudo apt update diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6653b937..7b8d605c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -70,7 +70,7 @@ Unit Testing: - Build Python 3.11 before_script: - apt -qq update && apt-get -qq install -y octave portaudio19-dev python-dev-is-python3 unzip # remove for docker - - pip install -qq -e .\[develop,test,quadriga,audio,sionna\] + - pip install -qq -e .\[develop,test,quadriga,audio,sionna,scapy\] - unzip dist/$HERMES_WHEEL_11 "hermespy/fec/aff3ct/*.so" - pip install -qq pyzmq>=25.1.1 usrp-uhd-client memray>=1.11.0 script: @@ -93,7 +93,7 @@ Integration Testing: - Build Python 3.11 before_script: - apt -qq update && apt-get -qq install -y octave portaudio19-dev python-dev-is-python3 # remove for docker - - pip install -qq dist/$HERMES_WHEEL_11\[test,quadriga,audio,sionna\] + - pip install -qq dist/$HERMES_WHEEL_11\[test,quadriga,audio,sionna,scapy\] - pip install -qq memray script: - python ./tests/test_install.py ./tests/integration_tests/ diff --git a/_examples/library/rotation.py b/_examples/library/rotation.py new file mode 100644 index 00000000..291f8007 --- /dev/null +++ b/_examples/library/rotation.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +import matplotlib.pyplot as plt +import numpy as np + +from hermespy.core import Transformation +from hermespy.simulation import Simulation, LinearTrajectory, StaticTrajectory + + +# Setup flags +flag_traj_static_a = False # Is alpha device trajectory static? +flag_traj_static_b = False # Is beta device trajectory static? +flag_lookat = True # Should the devices look at each other? + +# Create a new simulation featuring two devices +simulation = Simulation() +device_alpha = simulation.new_device() +device_beta = simulation.new_device() +duration = 60 + +# Init positions and poses +init_pose_alpha = Transformation.From_Translation(np.array([10., 10., 0.])) +fina_pose_alpha = Transformation.From_Translation(np.array([50., 50., 0.])) +init_pose_beta = Transformation.From_Translation(np.array([30., 10., 20.])) +fina_pose_beta = Transformation.From_Translation(np.array([30., 20., 20.])) +if flag_lookat: + init_pose_alpha = init_pose_alpha.lookat(init_pose_beta.translation) + fina_pose_alpha = fina_pose_alpha.lookat(fina_pose_beta.translation) + init_pose_beta = init_pose_beta.lookat(init_pose_alpha.translation) + fina_pose_beta = fina_pose_beta.lookat(fina_pose_alpha.translation) + +# Assign each device a trajectory +# alpha +if flag_traj_static_a: + device_alpha.trajectory = StaticTrajectory(init_pose_alpha) +else: + device_alpha.trajectory = LinearTrajectory(init_pose_alpha, fina_pose_alpha, duration) +# beta +if flag_traj_static_b: + device_beta.trajectory = StaticTrajectory(init_pose_beta) +else: + device_beta.trajectory = LinearTrajectory(init_pose_beta, fina_pose_beta, duration) + +# Lock the devices onto each other +if flag_lookat: + device_alpha.trajectory.lookat(device_beta.trajectory) + device_beta.trajectory.lookat(device_alpha.trajectory) + +visualization = simulation.scenario.visualize() +with plt.ion(): + for timestamp in np.linspace(0, duration, 200): + simulation.scenario.visualize.update_visualization(visualization, time=timestamp) + plt.pause(0.1) + plt.show() diff --git a/hermespy/channel/radar/radar.py b/hermespy/channel/radar/radar.py index 4d9702ef..98982578 100644 --- a/hermespy/channel/radar/radar.py +++ b/hermespy/channel/radar/radar.py @@ -11,6 +11,7 @@ from sparse import GCXS # type: ignore from hermespy.core import ( + AntennaMode, ChannelStateInformation, ChannelStateFormat, Direction, @@ -609,10 +610,10 @@ def propagation_response( ) -> np.ndarray: # Query the sensor array responses rx_response = receiver.antennas.cartesian_array_response( - carrier_frequency, self.position, "global" + carrier_frequency, self.position, "global", AntennaMode.RX ) tx_response = transmitter.antennas.cartesian_array_response( - carrier_frequency, self.position, "global" + carrier_frequency, self.position, "global", AntennaMode.TX ).conj() if self.attenuate: diff --git a/hermespy/channel/sionna_rt_channel.py b/hermespy/channel/sionna_rt_channel.py index c4c5ef07..03943686 100644 --- a/hermespy/channel/sionna_rt_channel.py +++ b/hermespy/channel/sionna_rt_channel.py @@ -86,10 +86,12 @@ def __apply_doppler(self, num_samples: int) -> tuple: tau (np.ndarray): delays. Shape (num_rx_ants, num_tx_ants, num_paths) """ # Apply doppler - self.paths.apply_doppler(sampling_frequency=self.bandwidth, - num_time_steps=num_samples, - tx_velocities=self.transmitter_velocity, - rx_velocities=self.receiver_velocity) + self.paths.apply_doppler( + sampling_frequency=self.bandwidth, + num_time_steps=num_samples, + tx_velocities=self.transmitter_velocity, + rx_velocities=self.receiver_velocity, + ) # Get and cast CIR a, tau = self.paths.cir() @@ -106,7 +108,7 @@ def state( self, num_samples: int, max_num_taps: int, - interpolation_mode: InterpolationMode = InterpolationMode.NEAREST + interpolation_mode: InterpolationMode = InterpolationMode.NEAREST, ) -> ChannelStateInformation: # Apply Doppler effect and get the channel impulse response a, tau = self.__apply_doppler(num_samples) @@ -114,11 +116,15 @@ def state( # Init result max_delay = np.max(tau) if tau.size != 0 else 0 max_delay_in_samples = min(max_num_taps, ceil(max_delay * self.bandwidth)) - raw_state = np.zeros(( - self.num_receive_antennas, - self.num_transmit_antennas, - num_samples, - 1 + max_delay_in_samples), dtype=np.complex_) + raw_state = np.zeros( + ( + self.num_receive_antennas, + self.num_transmit_antennas, + num_samples, + 1 + max_delay_in_samples, + ), + dtype=np.complex_, + ) # If no paths hit the target, then return an empty state if a.size == 0 or tau.size == 0: return ChannelStateInformation(ChannelStateFormat.IMPULSE_RESPONSE, raw_state) @@ -136,9 +142,7 @@ def state( return ChannelStateInformation(ChannelStateFormat.IMPULSE_RESPONSE, raw_state) def _propagate( - self, - signal_block: SignalBlock, - interpolation: InterpolationMode, + self, signal_block: SignalBlock, interpolation: InterpolationMode ) -> SignalBlock: # Calculate the resulting signal block parameters sr_ratio = self.receiver_state.sampling_rate / self.transmitter_state.sampling_rate @@ -150,16 +154,16 @@ def _propagate( a, tau = self.__apply_doppler(signal_block.num_samples) # If no paths hit the target, then return a zeroed signal if a.size == 0 or tau.size == 0: - return SignalBlock(np.zeros((num_streams_new, num_samples_new), - signal_block.dtype), - offset_new) + return SignalBlock( + np.zeros((num_streams_new, num_samples_new), signal_block.dtype), offset_new + ) # Set other attributes max_delay = np.max(tau) max_delay_in_samples = ceil(max_delay * self.bandwidth) propagated_samples = np.zeros( (num_streams_new, signal_block.num_samples + max_delay_in_samples), - dtype=signal_block.dtype + dtype=signal_block.dtype, ) # Prepare the optimal einsum path ahead of time for faster execution @@ -173,9 +177,9 @@ def _propagate( if tau_p == -1.0: continue t = int(tau_p * self.bandwidth) - propagated_samples[ - :, t : t + signal_block.num_samples - ] += np.einsum(einsum_subscripts, a_p, signal_block, optimize=einsum_path) + propagated_samples[:, t : t + signal_block.num_samples] += np.einsum( + einsum_subscripts, a_p, signal_block, optimize=einsum_path + ) propagated_samples *= np.sqrt(self.__gain) return SignalBlock(propagated_samples, offset_new) @@ -189,10 +193,9 @@ class SionnaRTChannelRealization(ChannelRealization[SionnaRTChannelSample]): scene: rt.Scene - def __init__(self, - scene: rt.Scene, - sample_hooks: Set[ChannelSampleHook] | None = None, - gain: float = 1.) -> None: + def __init__( + self, scene: rt.Scene, sample_hooks: Set[ChannelSampleHook] | None = None, gain: float = 1.0 + ) -> None: super().__init__(sample_hooks, gain) self.scene = scene @@ -205,18 +208,12 @@ def _sample(self, state: LinkState) -> SionnaRTChannelSample: # init self.scene.tx_array tx_antenna = rt.Antenna("iso", "V") - tx_positions = [ - a.position - for a in state.transmitter.antennas.transmit_antennas - ] + tx_positions = [a.position for a in state.transmitter.antennas.transmit_antennas] self.scene.tx_array = rt.AntennaArray(tx_antenna, tx_positions) # init self.scene.rx_array rx_antenna = rt.Antenna("iso", "V") - rx_positions = [ - a.position - for a in state.receiver.antennas.receive_antennas - ] + rx_positions = [a.position for a in state.receiver.antennas.receive_antennas] self.scene.rx_array = rt.AntennaArray(rx_antenna, rx_positions) # init tx and rx @@ -244,9 +241,7 @@ def to_HDF(self, group: Group) -> None: @staticmethod def From_HDF( - scene: rt.Scene, - group: Group, - sample_hooks: Set[ChannelSampleHook[SionnaRTChannelSample]] + scene: rt.Scene, group: Group, sample_hooks: Set[ChannelSampleHook[SionnaRTChannelSample]] ) -> SionnaRTChannelRealization: return SionnaRTChannelRealization(scene, sample_hooks, group.attrs["gain"]) diff --git a/hermespy/core/__init__.py b/hermespy/core/__init__.py index 9791bcb6..35389a01 100644 --- a/hermespy/core/__init__.py +++ b/hermespy/core/__init__.py @@ -65,7 +65,7 @@ register, ) from .random_node import RandomRealization, RandomNode -from .drop import Drop, RecalledDrop +from .drop import Drop from .scenario import Scenario, ScenarioMode, ScenarioType, ReplayScenario from .signal_model import Signal, SignalBlock, DenseSignal, SparseSignal from .visualize import ( @@ -168,7 +168,6 @@ "RandomRealization", "RandomNode", "Drop", - "RecalledDrop", "Scenario", "ScenarioMode", "ScenarioType", diff --git a/hermespy/core/device.py b/hermespy/core/device.py index e69144f6..ddaa233b 100644 --- a/hermespy/core/device.py +++ b/hermespy/core/device.py @@ -406,6 +406,10 @@ def to_HDF(self, group: Group) -> None: self.mixed_signal.to_HDF(self._create_group(group, "mixed_signal")) +DTT = TypeVar("DTT", bound="DeviceTransmission") +"""Type of device transmission.""" + + class DeviceTransmission(DeviceOutput): """Information generated by transmitting over a device.""" @@ -465,7 +469,21 @@ def num_operator_transmissions(self) -> int: return len(self.__operator_transmissions) @classmethod - def from_HDF(cls: Type[DeviceTransmission], group: Group) -> DeviceTransmission: + def from_HDF( + cls: Type[DeviceTransmission], group: Group, operators: Sequence[Transmitter] | None = None + ) -> DeviceTransmission: + """Recall a device transmission from a serialization. + + Args: + + group (Group): + HDF5 group containing the serialized device transmission. + + operators (Sequence[Transmitter], optional): + List of device transmitters to recall the specific transmissions. + If not specified, the transmissions are recalled as their base class. + """ + # Recall base class device_output = DeviceOutput.from_HDF(group) @@ -473,9 +491,16 @@ def from_HDF(cls: Type[DeviceTransmission], group: Group) -> DeviceTransmission: num_transmissions = group.attrs.get("num_transmissions", 1) # Recall transmissions - transmissions = [ - Transmission.from_HDF(group[f"transmission_{t:02d}"]) for t in range(num_transmissions) - ] + if operators is None: + transmissions = [ + Transmission.from_HDF(group[f"transmission_{t:02d}"]) + for t in range(num_transmissions) + ] + else: + transmissions = [ + operator.recall_transmission(group[f"transmission_{t:02d}"]) + for t, operator in zip(range(num_transmissions), operators) + ] # Initialize object return cls.From_Output(device_output, transmissions) @@ -706,20 +731,26 @@ def num_operator_receptions(self) -> int: return len(self.__operator_receptions) @classmethod - def from_HDF(cls: Type[DRT], group: Group) -> DRT: + def from_HDF(cls: Type[DRT], group: Group, operators: Sequence[Receiver] | None = None) -> DRT: # Recall base class device_input = ProcessedDeviceInput.from_HDF(group) # Recall individual operator receptions num_receptions = group.attrs.get("num_operator_receptions", 0) - # Recall operator receptions - operator_receptions = [ - Reception.from_HDF(group[f"reception_{f:02d}"]) for f in range(num_receptions) - ] + # Recall receptions + if operators is None: + receptions = [ + Reception.from_HDF(group[f"reception_{r:02d}"]) for r in range(num_receptions) + ] + else: + receptions = [ + operator.recall_reception(group[f"reception_{r:02d}"]) + for r, operator in zip(range(num_receptions), operators) + ] # Initialize object - return cls.From_ProcessedDeviceInput(device_input, operator_receptions) + return cls.From_ProcessedDeviceInput(device_input, receptions) @classmethod def Recall(cls: Type[DRT], group: Group, device: Device) -> DRT: @@ -1809,6 +1840,21 @@ def transmit(self, clear_cache: bool = True) -> DeviceTransmission: return DeviceTransmission.From_Output(device_output, operator_transmissions) + def recall_transmission(self, group: Group) -> DeviceTransmission: + """Recall a specific transmission from a HDF5 serialization. + + Args: + + group (Group): + HDF group containing the transmission. + + Returns: The recalled transmission. + """ + + # Recall the specific operator transmissions + + return DeviceTransmission.from_HDF(group, list(self.transmitters)) + def cache_transmission(self, transmission: DeviceTransmission) -> None: for transmitter, operator_transmission in zip( self.transmitters, transmission.operator_transmissions @@ -1953,3 +1999,16 @@ def receive( # Generate device reception return DeviceReception.From_ProcessedDeviceInput(processed_input, receptions) + + def recall_reception(self, group: Group) -> DeviceReception: + """Recall a specific reception from a HDF5 serialization. + + Args: + + group (Group): + HDF group containing the reception. + + Returns: The recalled reception. + """ + + return DeviceReception.Recall(group, self) diff --git a/hermespy/core/drop.py b/hermespy/core/drop.py index b897edea..d078ca89 100644 --- a/hermespy/core/drop.py +++ b/hermespy/core/drop.py @@ -7,18 +7,15 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Type, TYPE_CHECKING +from typing import Generic, Type, TypeVar from h5py import Group -from .device import DeviceReception, DeviceTransmission +from .device import Device, DeviceReception, DeviceTransmission, DRT, DTT from .factory import HDFSerializable from .signal_model import Signal from .monte_carlo import Artifact -if TYPE_CHECKING: - from .scenario import Scenario # pragma: no cover - __author__ = "Jan Adler" __copyright__ = "Copyright 2024, Barkhausen Institut gGmbH" __credits__ = ["Jan Adler"] @@ -29,19 +26,19 @@ __status__ = "Prototype" -class Drop(HDFSerializable): +class Drop(Generic[DTT, DRT], HDFSerializable): """Drop containing the information transmitted and received by all devices within a scenario.""" __timestamp: float # Time at which the drop was generated - __device_transmissions: Sequence[DeviceTransmission] # Transmitted device information - __device_receptions: Sequence[DeviceReception] # Received device information + __device_transmissions: Sequence[DTT] # Transmitted device information + __device_receptions: Sequence[DRT] # Received device information def __init__( self, timestamp: float, - device_transmissions: Sequence[DeviceTransmission], - device_receptions: Sequence[DeviceReception], + device_transmissions: Sequence[DTT], + device_receptions: Sequence[DRT], ) -> None: """ Args: @@ -49,10 +46,10 @@ def __init__( timestamp (float): Time at which the drop was generated. - device_transmissions (Sequence[DeviceTransmission]): + device_transmissions (Sequence[DTT]): Transmitted device information. - device_receptions (Sequence[DeviceReception]): + device_receptions (Sequence[DRT]): Received device information. """ @@ -67,13 +64,13 @@ def timestamp(self) -> float: return self.__timestamp @property - def device_transmissions(self) -> Sequence[DeviceTransmission]: + def device_transmissions(self) -> Sequence[DTT]: """Transmitted device information within this drop.""" return self.__device_transmissions @property - def device_receptions(self) -> Sequence[DeviceReception]: + def device_receptions(self) -> Sequence[DRT]: """Received device information within this drop.""" return self.__device_receptions @@ -100,19 +97,29 @@ def operator_inputs(self) -> Sequence[Sequence[Signal]]: return [reception.operator_inputs for reception in self.device_receptions] @classmethod - def from_HDF(cls: Type[Drop], group: Group) -> Drop: + def from_HDF(cls: Type[Drop], group: Group, devices: Sequence[Device] | None = None) -> Drop: # Recall attributes timestamp = group.attrs.get("timestamp", 0.0) - num_transmissions = group.attrs.get("num_transmissions", 0) - num_receptions = group.attrs.get("num_receptions", 0) - - transmissions = [ - DeviceTransmission.from_HDF(group[f"transmission_{t:02d}"]) - for t in range(num_transmissions) - ] - receptions = [ - DeviceReception.from_HDF(group[f"reception_{r:02d}"]) for r in range(num_receptions) - ] + num_transmissions: int = group.attrs.get("num_transmissions", 0) + num_receptions: int = group.attrs.get("num_receptions", 0) + + if devices is None: + transmissions = [ + DeviceTransmission.from_HDF(group[f"transmission_{t:02d}"]) + for t in range(num_transmissions) + ] + receptions = [ + DeviceReception.from_HDF(group[f"reception_{r:02d}"]) for r in range(num_receptions) + ] + else: + transmissions = [ + device.recall_transmission(group[f"transmission_{t:02d}"]) + for t, device in zip(range(num_transmissions), devices) + ] + receptions = [ + device.recall_reception(group[f"reception_{r:02d}"]) + for r, device in zip(range(num_receptions), devices) + ] drop = cls( timestamp=timestamp, device_transmissions=transmissions, device_receptions=receptions @@ -133,46 +140,8 @@ def to_HDF(self, group: Group) -> None: group.attrs["num_receptions"] = self.num_device_receptions -class RecalledDrop(Drop): - """Drop recalled from serialization containing the information transmitted and received by all devices - within a scenario.""" - - __group: Group - - def __init__(self, group: Group, scenario: Scenario) -> None: - # Recall attributes - timestamp = group.attrs.get("timestamp", 0.0) - num_transmissions = group.attrs.get("num_transmissions", 0) - num_receptions = group.attrs.get("num_receptions", 0) - - device_transmissions = [ - DeviceTransmission.Recall(group[f"transmission_{t:02d}"], device) - for t, device in zip(range(num_transmissions), scenario.devices) - ] - device_receptions = [ - DeviceReception.Recall(group[f"reception_{r:02d}"], device) - for r, device in zip(range(num_receptions), scenario.devices) - ] - - # Initialize base class - Drop.__init__( - self, - timestamp=timestamp, - device_transmissions=device_transmissions, - device_receptions=device_receptions, - ) - - # Initialize class attributes - self.__group = group - - @property - def group(self) -> Group: - """HDF group this drop was recalled from. - - Returns: Handle to an HDF group. - """ - - return self.__group +DropType = TypeVar("DropType", bound=Drop) +"""Type of a drop.""" class EvaluatedDrop(Drop): diff --git a/hermespy/core/scenario.py b/hermespy/core/scenario.py index e2901a74..643925a4 100644 --- a/hermespy/core/scenario.py +++ b/hermespy/core/scenario.py @@ -16,6 +16,7 @@ from h5py import File, Group from .device import ( + Device, DeviceInput, DeviceOutput, DeviceReception, @@ -28,7 +29,7 @@ Receiver, Operator, ) -from .drop import Drop, RecalledDrop +from .drop import Drop, DropType from .factory import Factory from .random_node import RandomNode from .signal_model import Signal @@ -66,7 +67,7 @@ class ScenarioMode(IntEnum): """ -class Scenario(ABC, RandomNode, TransformableBase, Generic[DeviceType]): +class Scenario(ABC, RandomNode, TransformableBase, Generic[DeviceType, DropType]): """A wireless scenario. Scenarios consist of several devices transmitting and receiving electromagnetic signals. @@ -857,7 +858,7 @@ def num_drops(self) -> int | None: return None @abstractmethod - def _drop(self) -> Drop: + def _drop(self) -> DropType: """Generate a single scenario drop. Wrapped by the scenario base class :meth:`.drop` method. @@ -867,14 +868,25 @@ def _drop(self) -> Drop: """ ... # pragma no cover - def drop(self) -> Drop: + @abstractmethod + def _recall_drop(self, group: Group) -> DropType: + """Recall a recorded drop from a HDF5 group. + + Args: + + group (Group): + HDF5 group containing the drop information. + + Returns: The recalled drop. + """ + ... # pragma no cover + + def drop(self) -> DropType: """Generate a single data drop from all scenario devices. Return: The generated drop information. """ - drop: Drop - if self.mode == ScenarioMode.REPLAY: # Recall the drop from the savefile for _ in range(self.__file.attrs["num_drops"]): @@ -882,7 +894,7 @@ def drop(self) -> Drop: self.__drop_counter = (self.__drop_counter + 1) % self.__file.attrs["num_drops"] if drop_path in self.__file: - drop = RecalledDrop(self.__file[drop_path], self) + drop = self._recall_drop(self.__file[drop_path]) break # Replay device operator transmissions @@ -912,8 +924,11 @@ def drop(self) -> Drop: """Type of scenario.""" -class ReplayScenario(Scenario): +class ReplayScenario(Scenario[Device, Drop]): """Scenario which is unable to generate drops.""" def _drop(self) -> Drop: raise RuntimeError("Replay scenario may not generate data drops.") + + def _recall_drop(self, group: Group) -> Drop: + return Drop.from_HDF(group) diff --git a/hermespy/core/signal_model.py b/hermespy/core/signal_model.py index 833c357d..047a0855 100644 --- a/hermespy/core/signal_model.py +++ b/hermespy/core/signal_model.py @@ -766,7 +766,9 @@ def __parse_slice(s: slice, dim_size: int) -> Tuple[int, int, int]: s1 = s1 if s1 >= 0 else s1 % dim_size return s0, s1, s2 - def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, int, bool]: + def _parse_validate_itemkey( + self, key: Any + ) -> Tuple[int, int, int, int, int, int, bool, bool, bool]: """Parse and validate key in __getitem__ and __setitem__. Raises: @@ -785,11 +787,13 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in s11 (int): samples stop s12 (int): samples step isboolmask (bool): True if key is a boolean mask, False otherwise. + should_flatten_streams (bool): True if numpy's getitem would flatten the streams (1) dimension with this key + should_flatten_samples (bool): True if numpy's getitem would flatten the samples (2) dimension with this key Note that if isboolmask is True, then all s?? take the following values: (0, self.num_streams, 1, 0, self.num_samples, 1). - Note that if the key references any dimansion with an integer index, + Note that if the key references any dimension with an integer index, then the corresponding result start will be the index, and stop is start+1. For example, if key is 1, then only stream 1 is need. Then s00 is 1 and s01 is 2. Numpy getitem of [1] and [1:2] differ in dimensions. Flattening of the second variant should be considered. @@ -798,6 +802,8 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in self_num_streams = self.num_streams self_num_samples = self.num_samples isboolmask = False + should_flatten_streams = False + should_flatten_samples = False # Key is a tuple of two # ====================================================================== @@ -819,6 +825,7 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in s00 = key[0] % self_num_streams s01 = s00 + 1 s02 = 1 + should_flatten_streams = True else: raise TypeError( f"Expected to get streams index as an integer or a slice, but got {type(key[0])}" @@ -834,10 +841,21 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in s10 = key[1] % self_num_samples s11 = s10 + 1 s12 = 1 + should_flatten_samples = True else: - raise TypeError(f"Samples key is of an unsupported type ({type(key[1])})") - - return s00, s01, s02, s10, s11, s12, False + raise TypeError(f"Samples key is ofan unsupported type ({type(key[1])})") + + return ( + s00, + s01, + s02, + s10, + s11, + s12, + False, + should_flatten_streams, + should_flatten_samples, + ) # ====================================================================== # done Key is a tuple of two @@ -861,6 +879,7 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in s00 = key % self_num_streams s01 = s00 + 1 s02 = 1 + should_flatten_streams = True # Key is a boolean mask or something unsupported else: try: @@ -874,7 +893,17 @@ def _parse_validate_itemkey(self, key: Any) -> Tuple[int, int, int, int, int, in except ValueError: raise TypeError(f"Unsupported key type {type(key)}") - return s00, s01, s02, s10, s11, s12, isboolmask + return ( + s00, + s01, + s02, + s10, + s11, + s12, + isboolmask, + should_flatten_streams, + should_flatten_samples, + ) def _find_affected_blocks(self, s10: int, s11: int) -> Tuple[int, int]: """Find indices of blocks that are affected by the given samples slice. @@ -917,7 +946,7 @@ def _find_affected_blocks(self, s10: int, s11: int) -> Tuple[int, int]: return b_start, b_stop - def getitem(self, key: Any = slice(None, None)) -> np.ndarray: + def getitem(self, key: Any = slice(None, None), unflatten: bool = True) -> np.ndarray: """Get specified samples. Works like np.ndarray.__getitem__, but de-sparsifies the signal. @@ -927,6 +956,9 @@ def getitem(self, key: Any = slice(None, None)) -> np.ndarray: a tuple (int, int), (int, slice), (slice, int), (slice, slice) or a boolean mask. Defaults to slice(None, None) (same as [:, :]) + unflatten (bool): + Set to True to ensure the result ndim to be 2 even if only one stream is selected. + Set to False to allow the numpy-like degenerate dimensions reduction. Examples: getitem(slice(None, None)): @@ -934,13 +966,19 @@ def getitem(self, key: Any = slice(None, None)) -> np.ndarray: Warning: can cause memory overflow if used with a sparse signal. getitem(0): Select and de-sparsify the first stream. + Result shape is (1, num_samples) + getitem(0, False): + Same, but allow the numpy flattening. + Result shape is (num_samples,) getitem((slice(None, 2), slice(50, 100))): Select streams 0, 1 and samples 50-99. Same as samples_matrix[:2, 50:100] - Returns: np.ndarray with ndim 2 and dtype np.complex_""" + Returns: np.ndarray with ndim 2 or less and dtype dtype np.complex_""" - s00, s01, s02, s10, s11, s12, isboolmask = self._parse_validate_itemkey(key) + s00, s01, s02, s10, s11, s12, isboolmask, should_flatten_streams, should_flatten_samples = ( + self._parse_validate_itemkey(key) + ) num_streams = -((s01 - s00) // -s02) num_samples = -((s11 - s10) // -s12) if self.num_samples == 0 or self.num_streams == 0: # if this signal is empty @@ -968,6 +1006,9 @@ def getitem(self, key: Any = slice(None, None)) -> np.ndarray: # ^b previous^^gap^^b_start/b_stop^ elif s11 > b.offset: res[:, b.offset - s10 :] = b[s00:s01:s02, :] + # Apply numpy-like flattening + if not unflatten and (should_flatten_streams or should_flatten_samples): + res = res.flatten() return res # assemble the result @@ -1012,7 +1053,13 @@ def getitem(self, key: Any = slice(None, None)) -> np.ndarray: ) if is_streams_step_reversing: return res[::s02, ::s12] - return res[s00:s01:s02, ::s12] + res = res[s00:s01:s02, ::s12] + + # Apply numpy-like flattening + if not unflatten and (should_flatten_streams or should_flatten_samples): + res = res.flatten() + + return res def getstreams(self, streams_key: int | slice | Sequence[int]) -> Signal: """Create a new signal like this, but with only the selected streams. @@ -1693,15 +1740,21 @@ def Create( ) -> DenseSignal: return DenseSignal(samples, sampling_rate, carrier_frequency, noise_power, delay, offsets) - def getitem(self, key: Any = slice(None, None)) -> np.ndarray: + def getitem(self, key: Any = slice(None, None), unflatten: bool = True) -> np.ndarray: """Reroutes the argument to the single block of this model. Refer the numpy.ndarray.__getitem__ documentation. The result is always a 2D ndarray.""" res = self._blocks[0].view(np.ndarray)[key] # de-flatten - if res.ndim == 1: - return res.reshape((1, res.size)) + if unflatten and res.ndim == 1: + streams_flattened, samples_flattened = self._parse_validate_itemkey(key)[-2:] + if streams_flattened and samples_flattened: + return res.reshape(()) + elif streams_flattened: + return res.reshape((1, res.size)) + elif samples_flattened: + return res.reshape((res.size, 1)) return res def __setitem__(self, key: Any, value: Any) -> None: @@ -1932,7 +1985,7 @@ def __from_dense(block: np.ndarray) -> List[SignalBlock]: def __setitem__(self, key: Any, value: Any) -> None: # parse and validate key - s00, s01, s02, s10, s11, s12, isboolmask = self._parse_validate_itemkey(key) + s00, s01, s02, s10, s11, s12, isboolmask, _, _ = self._parse_validate_itemkey(key) if s02 <= 0 or s12 <= 0: raise NotImplementedError("Only positive steps are implemented") if s12 != 1: diff --git a/hermespy/core/transformation.py b/hermespy/core/transformation.py index e65ac4a6..87405392 100644 --- a/hermespy/core/transformation.py +++ b/hermespy/core/transformation.py @@ -472,6 +472,57 @@ def invert(self) -> Transformation: return np.linalg.inv(self).view(Transformation) + def lookat( + self, + target: np.ndarray = np.array([0.0, 0.0, 0.0], float), + up: np.ndarray = np.array([0.0, 1.0, 0.0], float), + ) -> Transformation: + """Rotate and loook at the given coordinates. Modifies `orientation` property. + + Args: + target (np.ndarray): + Cartesean coordinates to look at. + Defaults to np.array([0., 0., 0.], float) + up (np.ndarray): + Global catesean sky vector. + Defines the upward direction of the local viewport. + Defaults to np.array([0., 1., 0.], float) + + Returns: + self (Transformation): This modified Transformation. + """ + + # Validate arguments + target_ = np.asarray(target) + if target_.shape != (3,): + raise ValueError( + f"Got target of an unexpected shape (expected (3,), got {target_.shape})" + ) + up_ = np.asarray(up) + if up_.shape != (3,): + raise ValueError(f"Got up of an unexpected shape (expected (3,), got {up_.shape})") + up_ /= np.linalg.norm(up_) + + # Calculate new orientation + # forward vector + pos = self.translation + f = target_ - pos + f_norm = np.linalg.norm(f) + f = f / f_norm if f_norm != 0.0 else pos # normalize + # side/right vector + s = np.cross(up_, f) + s_norm = np.linalg.norm(s) + s = s / s_norm if s_norm != 0.0 else up_ # normalize + # up vector + u = np.cross(f, s) + # Calcualte the new transformation matrix + self[:3, 0] = s + self[:3, 1] = u + self[:3, 2] = f + self[3, :] = [0.0, 0.0, 0.0, 1.0] + + return self + @classmethod def to_yaml( cls: Type[Transformation], representer: SafeRepresenter, node: Transformation @@ -763,6 +814,25 @@ def to_local_coordinates(self, arg_0: Transformable | Transformation | np.ndarra local_transformation = self.backwards_transformation @ arg_0 return local_transformation.view(Transformation) + def lookat( + self, + target: np.ndarray = np.array([0.0, 0.0, 0.0], float), + up: np.ndarray = np.array([0.0, 1.0, 0.0], float), + ) -> None: + """Rotate and loook at the given coordinates. Modifies `orientation` property. + + Args: + target (np.ndarray): + Cartesean coordinates to look at. + Defaults to np.ndarray([0., 0., 0.], float). + up (array of 3 numbers): + Global catesean sky vector. + Defines the upward direction of the local viewport. + Defaults to np.ndarray([0., 1., 0.], float). + """ + + self.pose.lookat(target, up) + def _kinematics_updated(self) -> None: # Clear the cached forwards transformation if the base has been updated if "forwards_transformation" in self.__dict__: diff --git a/hermespy/hardware_loop/scenario.py b/hermespy/hardware_loop/scenario.py index a105bab7..1f6f72b5 100644 --- a/hermespy/hardware_loop/scenario.py +++ b/hermespy/hardware_loop/scenario.py @@ -11,6 +11,8 @@ from time import time from typing import Generic, Optional, TypeVar +from h5py import Group + from hermespy.core import DeviceInput, DeviceReception, Scenario, Drop, Signal from hermespy.simulation import SimulatedDeviceReception, SimulationScenario, TriggerRealization from .physical_device import PDT @@ -25,7 +27,7 @@ __status__ = "Prototype" -class PhysicalScenario(Generic[PDT], Scenario[PDT]): +class PhysicalScenario(Generic[PDT], Scenario[PDT, Drop]): """Scenario of physical device bindings. Managing physical devices by a scenario enables synchronized triggering @@ -97,6 +99,9 @@ def _drop(self) -> Drop: return Drop(timestamp, device_transmissions, device_receptions) + def _recall_drop(self, group: Group) -> Drop: + return Drop.from_HDF(group, self.devices) + def add_device(self, device: PDT) -> None: Scenario.add_device(self, device) diff --git a/hermespy/jcas/matched_filtering.py b/hermespy/jcas/matched_filtering.py index 7abcb5d0..e9a8cf56 100644 --- a/hermespy/jcas/matched_filtering.py +++ b/hermespy/jcas/matched_filtering.py @@ -84,7 +84,8 @@ def _receive(self, signal: Signal) -> JCASReception: signal.append_samples( Signal.Create( np.zeros( - (1, required_num_received_samples - signal.num_samples), dtype=complex + (signal.num_streams, required_num_received_samples - signal.num_samples), + dtype=complex, ), self.sampling_rate, signal.carrier_frequency, @@ -93,20 +94,18 @@ def _receive(self, signal: Signal) -> JCASReception: ) ) - # Remove possible overhead samples if signal is too long - # resampled_signal.set_samples(re) - # sampled_signal[:, :num_samples] + # Digital receive beamformer + angle_bins, beamformed_samples = self._receive_beamform(signal) + # Transmit-receive correlation for range estimation + transmitted_samples = self.transmission.signal.getitem(0) correlation = ( - abs( - correlate( - signal.getitem(), self.transmission.signal.getitem(), mode="valid", method="fft" - ).flatten() - ) + abs(correlate(beamformed_samples, transmitted_samples, mode="valid", method="fft")) / self.transmission.signal.num_samples ) + lags = correlation_lags( - signal.num_samples, self.transmission.signal.num_samples, mode="valid" + beamformed_samples.shape[1], transmitted_samples.shape[1], mode="valid" ) # Append zeros for correct depth estimation @@ -114,10 +113,9 @@ def _receive(self, signal: Signal) -> JCASReception: # correlation = np.append(correlation, np.zeros(num_appended_zeros)) # Create the cube object - angle_bins = np.array([[0.0, 0.0]]) velocity_bins = np.array([0.0]) range_bins = 0.5 * lags[:num_propagated_samples] * resolution - cube_data = np.array([[correlation[:num_propagated_samples]]], dtype=float) + cube_data = correlation[:, None, :num_propagated_samples] cube = RadarCube(cube_data, angle_bins, velocity_bins, range_bins, self.carrier_frequency) # Infer the point cloud, if a detector has been configured diff --git a/hermespy/modem/__init__.py b/hermespy/modem/__init__.py index 18120c82..e9532d9b 100644 --- a/hermespy/modem/__init__.py +++ b/hermespy/modem/__init__.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from .bits_source import RandomBitsSource, StreamBitsSource +from .frame_generator import FrameGenerator, FrameGeneratorStub from .symbols import Symbol, Symbols, StatedSymbols from .modem import ( CommunicationReception, @@ -110,6 +111,8 @@ __all__ = [ "RandomBitsSource", "StreamBitsSource", + "FrameGenerator", + "FrameGeneratorStub", "Symbol", "Symbols", "StatedSymbols", diff --git a/hermespy/modem/frame_generator/__init__.py b/hermespy/modem/frame_generator/__init__.py new file mode 100644 index 00000000..e67391cc --- /dev/null +++ b/hermespy/modem/frame_generator/__init__.py @@ -0,0 +1,3 @@ +from .frame_generator import FrameGenerator, FrameGeneratorStub + +__all__ = ["FrameGenerator", "FrameGeneratorStub"] diff --git a/hermespy/modem/frame_generator/frame_generator.py b/hermespy/modem/frame_generator/frame_generator.py new file mode 100644 index 00000000..200cc703 --- /dev/null +++ b/hermespy/modem/frame_generator/frame_generator.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod + +import numpy as np + +from hermespy.core import Serializable +from ..bits_source import BitsSource + + +class FrameGenerator(ABC, Serializable): + """Base class for frame generators.""" + + @abstractmethod + def pack_frame(self, source: BitsSource, num_bits: int) -> np.ndarray: + """Generate a frame of num_bits bits from the given bitsource. + + Args: + source (BitsSource): payload source. + num_bits (int): number of bits in the whole resulting frame. + + Return: + frame (np.ndarray): array of ints with each element beeing an individual bit. + """ + ... + + @abstractmethod + def unpack_frame(self, frame: np.ndarray) -> np.ndarray: + """Extract the original payload from the frame generated with pack_frame. + + Args: + frame (np.ndarray): array of bits of a frame, generated with pack_frame. + + Return: + payload (np.ndarray): array of payload bits.""" + ... + + +class FrameGeneratorStub(FrameGenerator): + """A dummy placeholder frame generator, packing and unpacking payload without any overhead.""" + + yaml_tag = "GeneratorStub" + + def pack_frame(self, source: BitsSource, num_bits: int) -> np.ndarray: + return source.generate_bits(num_bits) + + def unpack_frame(self, frame: np.ndarray) -> np.ndarray: + return frame diff --git a/hermespy/modem/frame_generator/scapy.py b/hermespy/modem/frame_generator/scapy.py new file mode 100644 index 00000000..80a8fcd0 --- /dev/null +++ b/hermespy/modem/frame_generator/scapy.py @@ -0,0 +1,57 @@ +from .frame_generator import FrameGenerator +from ..bits_source import BitsSource + +import numpy as np +from typing import Type + +from scapy.packet import Packet, raw # type: ignore + + +class FrameGeneratorScapy(FrameGenerator): + """Scapy wrapper frame generator. + + Attrs: + packet(Packet): Scapy packet header to which a payload would be attached. + packet_type(Type[Packet]): Type of the first layer of the packet header. + """ + + packet: Packet + packet_type: Type[Packet] + + def __init__(self, packet: Packet) -> None: + """ + Args: + packet(Packet): Packet to which a payload will be attached. + """ + self.packet = packet + self.packet_num_bits = len(packet) * 8 + self.packet_type = packet.layers()[0] + + def pack_frame(self, source: BitsSource, num_bits: int) -> np.ndarray: + """Generate a frame of num_bits bits from the given bitsource. + Note that the payload size is num_bits minus number of bits in the packet header. + Note that payload can be of size 0, in which case no data would be sent (except for the packet header). + + Args: + source (BitsSource): payload source. + num_bits (int): number of bits in the whole resulting frame. + + Raises: + ValueError if num_bits is not enough to fit the packet. + """ + + payload_num_bits = num_bits - self.packet_num_bits + if payload_num_bits < 0: + raise ValueError( + f"Packet header is bigger then the requested amount of bits ({len(self.packet)*8} > {num_bits})." + ) + packet_new = self.packet_type() + packet_new.add_payload(np.packbits(source.generate_bits(payload_num_bits)).tobytes()) + return np.unpackbits(np.frombuffer(raw(packet_new), np.uint8)) + + def unpack_frame(self, frame: np.ndarray) -> np.ndarray: + if frame.size < self.packet_num_bits: + raise ValueError( + f"The frame contains less bits then the header ({frame.size} < {self.packet_num_bits})." + ) + return frame[self.packet_num_bits :] diff --git a/hermespy/modem/modem.py b/hermespy/modem/modem.py index 6c8d4537..68d5ab0c 100644 --- a/hermespy/modem/modem.py +++ b/hermespy/modem/modem.py @@ -25,6 +25,7 @@ from .bits_source import BitsSource, RandomBitsSource from .symbols import StatedSymbols, Symbols from .waveform import CommunicationWaveform, CWT +from .frame_generator import FrameGenerator, FrameGeneratorStub __author__ = "Jan Adler" __copyright__ = "Copyright 2024, Barkhausen Institut gGmbH" @@ -444,6 +445,7 @@ class BaseModem(ABC, Generic[CWT], RandomNode): __encoder_manager: EncoderManager __precoding: SymbolPrecoding __waveform: CWT | None + __frame_generator: FrameGenerator @staticmethod def _arg_signature() -> Set[str]: @@ -454,6 +456,7 @@ def __init__( encoding: EncoderManager | None = None, precoding: SymbolPrecoding | None = None, waveform: CWT | None = None, + frame_generator: FrameGenerator | None = None, seed: int | None = None, ) -> None: """ @@ -479,6 +482,7 @@ def __init__( self.encoder_manager = EncoderManager() if encoding is None else encoding self.precoding = SymbolPrecoding(modem=self) if precoding is None else precoding self.waveform = waveform + self.frame_generator = FrameGeneratorStub() if frame_generator is None else frame_generator @property @abstractmethod @@ -531,6 +535,14 @@ def waveform(self, value: CWT | None) -> None: value.modem = self value.random_mother = self + @property + def frame_generator(self) -> FrameGenerator: + return self.__frame_generator + + @frame_generator.setter + def frame_generator(self, value: FrameGenerator) -> None: + self.__frame_generator = value + @property def precoding(self) -> SymbolPrecoding: """Description of the modem's precoding on a symbol level.""" @@ -790,10 +802,10 @@ def _transmit(self, duration: float = -1.0) -> CommunicationTransmission: frames: List[CommunicationTransmissionFrame] = [] for n in range(num_mimo_frames): # Generate plain data bits - data_bits = self.bits_source.generate_bits(required_num_data_bits) + frame_bits = self.frame_generator.pack_frame(self.bits_source, required_num_data_bits) # Apply forward error correction - encoded_bits = self.encoder_manager.encode(data_bits, required_num_code_bits) + encoded_bits = self.encoder_manager.encode(frame_bits, required_num_code_bits) # Map bits to communication symbols mapped_symbols = self.__map(encoded_bits, self.precoding.num_input_streams) @@ -825,7 +837,7 @@ def _transmit(self, duration: float = -1.0) -> CommunicationTransmission: frames.append( CommunicationTransmissionFrame( signal=frame_signal, - bits=data_bits, + bits=frame_bits, encoded_bits=encoded_bits, symbols=mapped_symbols, encoded_symbols=encoded_symbols, @@ -1046,6 +1058,9 @@ def _receive(self, signal: Signal) -> CommunicationReception: # Apply inverse FEC configuration to correct errors and remove redundancies decoded_bits = self.encoder_manager.decode(encoded_bits, required_num_data_bits) + # Decode the frame + payload_bits = self.frame_generator.unpack_frame(decoded_bits) + # Store the received information frames.append( CommunicationReceptionFrame( @@ -1056,7 +1071,7 @@ def _receive(self, signal: Signal) -> CommunicationReception: timestamp=frame_index * signal.sampling_rate, equalized_symbols=equalized_symbols, encoded_bits=encoded_bits, - decoded_bits=decoded_bits, + decoded_bits=payload_bits, ) ) diff --git a/hermespy/modem/waveform_chirp_fsk.py b/hermespy/modem/waveform_chirp_fsk.py index cf62a3d9..2d1f5533 100644 --- a/hermespy/modem/waveform_chirp_fsk.py +++ b/hermespy/modem/waveform_chirp_fsk.py @@ -30,7 +30,7 @@ __status__ = "Prototype" -scipy_minor_version = int(version("scipy").split('.')[1]) +scipy_minor_version = int(version("scipy").split(".")[1]) class ChirpFSKWaveform(PilotCommunicationWaveform, Serializable): @@ -473,7 +473,9 @@ def _prototypes(self) -> Tuple[np.ndarray, float]: if scipy_minor_version < 14: phase = integrate.cumtrapz(frequency, dx=1 / self.sampling_rate, initial=0) else: - phase = integrate.cumulative_trapezoid(frequency, dx=1 / self.sampling_rate, initial=0) + phase = integrate.cumulative_trapezoid( + frequency, dx=1 / self.sampling_rate, initial=0 + ) phase *= 2 * np.pi prototypes[idx, :] = np.exp(1j * phase) diff --git a/hermespy/radar/radar.py b/hermespy/radar/radar.py index 345fd522..21f4a2ab 100644 --- a/hermespy/radar/radar.py +++ b/hermespy/radar/radar.py @@ -389,10 +389,10 @@ def _receive_beamform(self, signal: Signal) -> Tuple[np.ndarray, np.ndarray]: RuntimeError: If the beamforming configuration does not result in a single output stream. """ - if self.device.antennas.num_antennas > 1: + if self.device.antennas.num_receive_ports > 1: if self.receive_beamformer is None: raise RuntimeError( - "Receiving over a device with more than one antenna requires a beamforming configuration" + "Receiving over a device with more than one RF port requires a beamforming configuration" ) if self.receive_beamformer.num_receive_output_streams != 1: @@ -402,7 +402,7 @@ def _receive_beamform(self, signal: Signal) -> Tuple[np.ndarray, np.ndarray]: if ( self.receive_beamformer.num_receive_input_streams - != self.device.antennas.num_antennas + != self.device.antennas.num_receive_ports ): raise RuntimeError( "Radar operator receive beamformers are required to consider the full number of antenna streams" diff --git a/hermespy/simulation/animation.py b/hermespy/simulation/animation.py index d093db18..f22b9201 100644 --- a/hermespy/simulation/animation.py +++ b/hermespy/simulation/animation.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod import numpy as np +from scipy.spatial.transform import Slerp, Rotation from hermespy.core import Serializable, Transformation @@ -56,6 +57,32 @@ def velocity(self) -> np.ndarray: class Trajectory(ABC): """Base class for motion trajectories of moveable objects within simulation scenarios.""" + # lookat attributes + _lookat_flag: bool = False + _lookat_target: Trajectory = None + _lookat_up: np.ndarray = np.array([0.0, 1.0, 0.0], float) # (3,), float + + def lookat(self, target: Trajectory, up: np.ndarray = np.array([0.0, 1.0, 0.0], float)) -> None: + """Set a target to look at and track. + + Args: + target (Trajectory): Target trajectory. + up (np.ndarray): Up/sky/head/ceiling global unit vector. Defaults to [0., 1., 0.]. + """ + + self._lookat_flag = True + self._lookat_target = target + self._lookat_up = up + + def lookat_disable(self) -> None: + self._lookat_flag = False + + def lookat_enable(self) -> None: + if self._lookat_target is None: + raise RuntimeError('Cannot enable lookat whithout a target. Use the "lookat" method.') + + self._lookat_flag = True + @property @abstractmethod def max_timestamp(self) -> float: @@ -67,16 +94,59 @@ def max_timestamp(self) -> float: ... # pragma: no cover @abstractmethod + def sample_velocity(self, timestamp: float) -> np.ndarray: + """Sample the trajectory's velocity. + + Args: + timestamp (float): Time at which to sample the trajectory in seconds. + + Returns: A sample of the trajectory's velocity (vector (3,) of floats). + """ + ... # pragma: no cover + + @abstractmethod + def sample_translation(self, timestamp: float) -> np.ndarray: + """Sample the trajectory's translation. + + Args: + timestamp (float): Time at which to sample the trajectory in seconds. + + Returns: A sample of the trajectory's translation (vector (3,) of floats). + """ + ... # pragma: no cover + + @abstractmethod + def sample_orientation(self, timestamp: float) -> np.ndarray: + """Sample the trajectory's orientation. Does not consider lookat. + + Args: + timestamp (float): Time at which to sample the trajectory in seconds. + + Returns: A sample of the trajectory's orientation matrix (matrix (3, 3) of float). + """ + ... # pragma: no cover + def sample(self, timestamp: float) -> TrajectorySample: """Sample the trajectory at a given point in time. Args: - timestamp (float): Time at which to sample the trajectory in seconds. Returns: A sample of the trajectory. """ - ... # pragma: no cover + + # Init transformation and sample position + transformation = np.eye(4, 4, dtype=float).view(Transformation) + transformation[:3, 3] = self.sample_translation(timestamp) + + # Sample orientation + if self._lookat_flag: + target_translation = self._lookat_target.sample_translation(timestamp) + transformation = transformation.lookat(target_translation, self._lookat_up) + else: + transformation[:3, :3] = self.sample_orientation(timestamp) + + return TrajectorySample(timestamp, transformation, self.sample_velocity(timestamp)) class LinearTrajectory(Trajectory): @@ -104,37 +174,31 @@ def __init__( # Initialize class attributes self.__initial_pose = initial_pose - self.__final_pose = final_pose self.__duration = duration self.__start = start # Infer velocity from start and end poses self.__velocity = (final_pose.translation - initial_pose.translation) / duration - self.__initial_quaternion = initial_pose.rotation_quaternion - self.__quaternion_velocity = ( - final_pose.rotation_quaternion - initial_pose.rotation_quaternion - ) / duration + rotations = Rotation.from_matrix([initial_pose[:3, :3], final_pose[:3, :3]]) + self.__slerp = Slerp([start, start + duration], rotations) @property def max_timestamp(self) -> float: return self.__start + self.__duration - def sample(self, timestamp: float) -> TrajectorySample: - - # If the timestamp is outside the trajectory, return the initial or final pose - if timestamp < self.__start: - return TrajectorySample(timestamp, self.__initial_pose, np.zeros(3, dtype=np.float_)) - - if timestamp >= self.__start + self.__duration: - return TrajectorySample(timestamp, self.__final_pose, np.zeros(3, dtype=np.float_)) + def sample_velocity(self, timestamp: float) -> np.ndarray: + if timestamp >= self.__start and timestamp < self.__start + self.__duration: + return self.__velocity + else: + return np.zeros(3, np.float_) - # Interpolate orientation and position - t = timestamp - self.__start - orientation = self.__initial_quaternion + t * self.__quaternion_velocity - translation = self.__initial_pose.translation + t * self.__velocity - transformation = Transformation.From_Quaternion(orientation, translation) + def sample_translation(self, timestamp: float) -> np.ndarray: + t = np.clip(timestamp, self.__start, self.__start + self.__duration) - self.__start + return self.__initial_pose.translation + t * self.__velocity - return TrajectorySample(timestamp, transformation, self.__velocity) + def sample_orientation(self, timestamp: float) -> np.ndarray: + t = np.clip(timestamp, self.__start, self.__start + self.__duration) + return self.__slerp(t).as_matrix() class StaticTrajectory(Serializable, Trajectory): @@ -169,6 +233,15 @@ def velocity(self) -> np.ndarray: def max_timestamp(self) -> float: return 0.0 + def sample_velocity(self, timestamp: float) -> np.ndarray: + return self.__velocity + + def sample_translation(self, timestamp: float) -> np.ndarray: + return self.__pose.translation + + def sample_orientation(self, timestamp: float) -> np.ndarray: + return self.__pose[:3, :3] + def sample(self, timestamp: float) -> TrajectorySample: return TrajectorySample(timestamp, self.__pose, self.__velocity) @@ -229,17 +302,8 @@ def __init__(self, height: float, duration: float) -> None: def max_timestamp(self) -> float: return self.__duration - def sample(self, timestamp: float) -> TrajectorySample: - - if timestamp >= self.__duration: - return TrajectorySample( - timestamp, - Transformation.From_Translation( - np.array([0, 0.5 * self.__height, 0], dtype=np.float_) - ), - np.zeros(3, dtype=np.float_), - ) - + def __start_end_point_time(self, timestamp: float) -> tuple: + """Returns start start_point, end_point, start_time and end_time of the straight path section.""" if timestamp > self.__duration * 4 / 5: start_point = self.__height * np.array([0.5, 0.5, 0], dtype=np.float_) end_point = self.__height * np.array([0, 0.5, 0], dtype=np.float_) @@ -264,10 +328,19 @@ def sample(self, timestamp: float) -> TrajectorySample: start_time = 0 end_time = self.__duration * 2 / 5 - leg_duration = end_time - start_time - velocity = (end_point - start_point) / leg_duration - interpolated_position = start_point + velocity * (timestamp - start_time) + return start_point, end_point, start_time, end_time + + def sample_velocity(self, timestamp: float) -> np.ndarray: + start_point, end_point, start_time, end_time = self.__start_end_point_time(timestamp) + if timestamp <= start_time or timestamp >= end_time: + return np.zeros(3, np.float_) + return (end_point - start_point) / end_time - start_time + + def sample_translation(self, timestamp: float) -> np.ndarray: + start_point, _, start_time, end_time = self.__start_end_point_time(timestamp) + if timestamp <= start_time or timestamp >= end_time: + return np.array([0, 0.5 * self.__height, 0], np.float_) + return start_point + self.sample_velocity(timestamp) * (timestamp - start_time) - return TrajectorySample( - timestamp, Transformation.From_Translation(interpolated_position), velocity - ) + def sample_orientation(self, timestamp: float) -> np.ndarray: + return np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], float) diff --git a/hermespy/simulation/drop.py b/hermespy/simulation/drop.py index 69827c73..5f8d1385 100644 --- a/hermespy/simulation/drop.py +++ b/hermespy/simulation/drop.py @@ -1,17 +1,14 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from typing import List, Sequence, Type, TYPE_CHECKING +from typing import List, Sequence, Type from h5py import Group -from hermespy.channel import ChannelRealization -from hermespy.core import Drop +from hermespy.channel import Channel, ChannelRealization +from hermespy.core import Device, Drop from .simulated_device import SimulatedDeviceReception, SimulatedDeviceTransmission -if TYPE_CHECKING: - from hermespy.simulation import SimulationScenario # pragma: no cover - __author__ = "Jan Adler" __copyright__ = "Copyright 2024, Barkhausen Institut gGmbH" __credits__ = ["Jan Adler"] @@ -22,7 +19,7 @@ __status__ = "Prototype" -class SimulatedDrop(Drop): +class SimulatedDrop(Drop[SimulatedDeviceTransmission, SimulatedDeviceReception]): """Drop containing all information generated during a simulated wireless scenario transmission, channel propagation and reception.""" @@ -51,7 +48,10 @@ def __init__( Received device information. """ + # Initialize attributes self.__channel_realizations = channel_realizations + + # Initialize base class Drop.__init__(self, timestamp, device_transmissions, device_receptions) @property @@ -80,37 +80,43 @@ def to_HDF(self, group: Group) -> None: @classmethod def from_HDF( - cls: Type[SimulatedDrop], group: Group, scenario: SimulationScenario | None = None + cls: Type[SimulatedDrop], + group: Group, + devices: Sequence[Device] | None = None, + channels: Sequence[Channel] | None = None, ) -> SimulatedDrop: - # Require a scenario to be specified - # Maybe there is a workaround possible since this is validates object-oriented principles - if scenario is None: - raise ValueError("Simulation drops must be deserialized with a scenario instance") + """Recall a simulated drop from a HDF5 group. + + Args: + + group (Group): The HDF5 group containing the serialized drop. + devices (Sequence[Device], optional): The devices participating in the scenario. + channels (Sequence[Channel], optional): The channels used in the scenario. + """ # Recall attributes timestamp = group.attrs.get("timestamp", 0.0) num_transmissions = group.attrs.get("num_transmissions", 0) num_receptions = group.attrs.get("num_receptions", 0) num_devices = group.attrs.get("num_devices", 1) - - # Assert that the scenario parameters match the serialization - if scenario.num_devices != num_devices: - raise ValueError( - f"Number of scenario devices does not match the serialization ({scenario.num_devices} != {num_devices})" - ) + _devices = [None] * num_devices if devices is None else devices # Recall groups transmissions = [ - SimulatedDeviceTransmission.from_HDF(group[f"transmission_{t:02d}"]) - for t in range(num_transmissions) + SimulatedDeviceTransmission.from_HDF( + group[f"transmission_{t:02d}"], None if d is None else list(d.transmitters) + ) + for t, d in zip(range(num_transmissions), _devices) ] receptions = [ - SimulatedDeviceReception.from_HDF(group[f"reception_{r:02d}"]) - for r in range(num_receptions) + SimulatedDeviceReception.from_HDF( + group[f"reception_{r:02d}"], None if d is None else list(d.receivers) + ) + for r, d in zip(range(num_receptions), _devices) ] channel_realizations: List[ChannelRealization] = [] - for c, channel in enumerate(scenario.channels): + for c, channel in enumerate(channels): realization = channel.recall_realization(group[f"channel_realization_{c:02d}"]) channel_realizations.append(realization) diff --git a/hermespy/simulation/scenario.py b/hermespy/simulation/scenario.py index 3df8385a..330798c9 100644 --- a/hermespy/simulation/scenario.py +++ b/hermespy/simulation/scenario.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import numpy as np +from h5py import Group from mpl_toolkits.mplot3d.axes3d import Axes3D # type: ignore from mpl_toolkits.mplot3d.art3d import Line3DCollection # type: ignore @@ -180,7 +181,7 @@ def _update_visualization( ) -class SimulationScenario(Scenario[SimulatedDevice]): +class SimulationScenario(Scenario[SimulatedDevice, SimulatedDrop]): """Description of a physical layer wireless communication scenario.""" yaml_tag = "SimulationScenario" @@ -748,6 +749,9 @@ def _drop(self) -> SimulatedDrop: timestamp, device_transmissions, channel_realizations, device_receptions ) + def _recall_drop(self, group: Group) -> SimulatedDrop: + return SimulatedDrop.from_HDF(group, self.devices, self.channels) + @property def visualize(self) -> _ScenarioVisualizer: return self.__visualizer diff --git a/hermespy/simulation/simulated_device.py b/hermespy/simulation/simulated_device.py index 0a49219d..7c5e9296 100644 --- a/hermespy/simulation/simulated_device.py +++ b/hermespy/simulation/simulated_device.py @@ -21,6 +21,7 @@ RandomNode, Transformation, Transmission, + Transmitter, Reception, Scenario, Serializable, @@ -624,10 +625,12 @@ def From_SimulatedDeviceOutput( @classmethod def from_HDF( - cls: Type[SimulatedDeviceTransmission], group: Group + cls: Type[SimulatedDeviceTransmission], + group: Group, + operators: Sequence[Transmitter] | None = None, ) -> SimulatedDeviceTransmission: # Recover base classes - device_transmission = DeviceTransmission.from_HDF(group) + device_transmission = DeviceTransmission.from_HDF(group, operators) devic_output = SimulatedDeviceOutput.from_HDF(group) # Initialize class from device output and operator transmissions @@ -891,9 +894,13 @@ def to_HDF(self, group: Group) -> None: DeviceReception.to_HDF(self, group) @classmethod - def from_HDF(cls: Type[SimulatedDeviceReception], group: Group) -> SimulatedDeviceReception: + def from_HDF( + cls: Type[SimulatedDeviceReception], + group: Group, + operators: Sequence[Receiver] | None = None, + ) -> SimulatedDeviceReception: device_input = ProcessedSimulatedDeviceInput.from_HDF(group) - device_reception = DeviceReception.from_HDF(group) + device_reception = DeviceReception.from_HDF(group, operators) return cls.From_ProcessedSimulatedDeviceInput( device_input, device_reception.operator_receptions diff --git a/pyproject.toml b/pyproject.toml index 6f49b4eb..584e129e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,9 @@ develop = [ sionna = [ "sionna>=0.17.0", ] +scapy = [ + "scapy>=2.5.0", +] [tool.scikit-build] cmake.verbose = true diff --git a/submodules/affect b/submodules/affect index 61509eb7..8fa65a3c 160000 --- a/submodules/affect +++ b/submodules/affect @@ -1 +1 @@ -Subproject commit 61509eb756ae3725b8a67c2d26a5af5ba95186fb +Subproject commit 8fa65a3ca9b0dcdd3d544363bc692d4f85f6f718 diff --git a/tests/integration_tests/test_scenario.py b/tests/integration_tests/test_scenario.py index 1d6893f2..bfec7cd7 100644 --- a/tests/integration_tests/test_scenario.py +++ b/tests/integration_tests/test_scenario.py @@ -7,8 +7,7 @@ from numpy.testing import assert_array_almost_equal -from hermespy.core import Drop -from hermespy.simulation import SimulationScenario +from hermespy.simulation import SimulatedDrop, SimulationScenario from hermespy.modem import TransmittingModem, ReceivingModem, RaisedCosineWaveform __author__ = "Jan Adler" @@ -49,7 +48,7 @@ def tearDown(self) -> None: self.scenario.stop() self.tempdir.cleanup() - def _record(self) -> List[Drop]: + def _record(self) -> List[SimulatedDrop]: """Record some drops for testing. Returns: List of recorded drops. diff --git a/tests/unit_tests/core/test_drop.py b/tests/unit_tests/core/test_drop.py index 0e9550ed..95b43e5d 100644 --- a/tests/unit_tests/core/test_drop.py +++ b/tests/unit_tests/core/test_drop.py @@ -8,7 +8,7 @@ import numpy as np from h5py import File, Group -from hermespy.core import Drop, RecalledDrop, Signal, DeviceTransmission, DeviceReception +from hermespy.core import Drop, Signal, DeviceTransmission, DeviceReception from hermespy.core.drop import EvaluatedDrop __author__ = "Jan Adler" @@ -60,7 +60,7 @@ def test_hdf_serialization(self) -> None: mock_device.receivers = [] mock_device.transmitters = [] mock_scenario.devices = [mock_device] - recalled_drop = RecalledDrop(file["g1"], mock_scenario) + recalled_drop = Drop.from_HDF(file["g1"], mock_scenario.devices) file.close() self.assertEqual(self.drop.timestamp, deserialization.timestamp) @@ -69,7 +69,6 @@ def test_hdf_serialization(self) -> None: self.assertEqual(self.drop.timestamp, recalled_drop.timestamp) self.assertEqual(self.drop.num_device_transmissions, recalled_drop.num_device_transmissions) self.assertEqual(self.drop.num_device_receptions, recalled_drop.num_device_receptions) - self.assertIsInstance(recalled_drop.group, Group) class TestEvaluatedDrop(TestCase): diff --git a/tests/unit_tests/core/test_scenario.py b/tests/unit_tests/core/test_scenario.py index f465ce8b..a59e835a 100644 --- a/tests/unit_tests/core/test_scenario.py +++ b/tests/unit_tests/core/test_scenario.py @@ -7,9 +7,9 @@ from unittest.mock import PropertyMock, MagicMock, Mock, patch import numpy.random as rnd -from h5py import File +from h5py import File, Group -from hermespy.core import Drop, Scenario, ScenarioMode, Signal, SignalReceiver, SilentTransmitter, ReplayScenario +from hermespy.core import Device, Drop, Scenario, ScenarioMode, Signal, SignalReceiver, SilentTransmitter, ReplayScenario from hermespy.simulation import SimulatedDevice __author__ = "Tobias Kronauer" @@ -22,7 +22,7 @@ __status__ = "Prototype" -class MockScenario(Scenario): +class MockScenario(Scenario[Device, Drop]): """Implementation of abstract scenario base for testing purpuses""" def _drop(self) -> Drop: @@ -31,6 +31,9 @@ def _drop(self) -> Drop: receptions = self.receive_devices([t.mixed_signal for t in transmissions]) return Drop(0.0, transmissions, receptions) + def _recall_drop(self, group: Group) -> Drop: + return Drop.from_HDF(group) + class TestScenario(TestCase): """Test scenario base class""" diff --git a/tests/unit_tests/core/test_signal.py b/tests/unit_tests/core/test_signal.py index 2a2fcc0b..8ea859b4 100644 --- a/tests/unit_tests/core/test_signal.py +++ b/tests/unit_tests/core/test_signal.py @@ -234,6 +234,12 @@ def test_setgetitem(self) -> None: for key in keys: assert_array_equal(self.samples_dense[key].flatten(), self.signal.getitem(key).flatten()) + # getitem with unflatten=False + assert_array_equal(self.signal.getitem(0, False).shape, + (self.signal.num_samples,)) + assert_array_equal(self.signal.getitem((slice_full, 0), False).shape, + (self.signal.num_streams,)) + # __setitem__ dummy_value = 13.37 + 73.31j dummy_samples_full = np.full((self.num_streams, self.num_samples), diff --git a/tests/unit_tests/modem/test_frame_generators.py b/tests/unit_tests/modem/test_frame_generators.py new file mode 100644 index 00000000..76b0a476 --- /dev/null +++ b/tests/unit_tests/modem/test_frame_generators.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +"""HermesPy FrameGenerators testing.""" + +from hermespy.modem.bits_source import RandomBitsSource +from hermespy.modem.frame_generator import FrameGeneratorStub +from hermespy.modem.frame_generator.scapy import FrameGeneratorScapy + +import unittest +import numpy as np + +from numpy.testing import assert_array_equal + +from scapy.layers.dot11 import Dot11 + + +class TestFrameGeneratorStub(unittest.TestCase): + """Test the placeholder stub frame generator""" + + def setUp(self) -> None: + self.fg = FrameGeneratorStub() + self.bs = RandomBitsSource(42) + + def test_pack_unpack(self): + for num_frame_bits in [0, 1, 8, 2**10, 11]: + frame = self.fg.pack_frame(self.bs, num_frame_bits) + payload = self.fg.unpack_frame(frame) + # TODO how can this be asserted? + + +class TestFrameGeneratorScapy(unittest.TestCase): + """Test the Scapy wrapper frame generator""" + + def setUp(self) -> None: + self.packet_base = Dot11(proto=1, ID=1337, addr1='01:23:45:67:89:ab', addr2='ff:ee:dd:cc:bb:aa') + self.packet_base_num_bits = len(self.packet_base)*8 + self.fg = FrameGeneratorScapy(self.packet_base) + self.bs = RandomBitsSource(42) + + def test_pack_unpack(self): + """Test pack_bits and unpack_bits with the 802.11 Scapy implementation""" + + # Try packing frames with valid number of bits + for num_bits in np.array([8, 16, 2**10]) + self.packet_base_num_bits: + packet = self.fg.pack_frame(self.bs, num_bits) + payload = self.fg.unpack_frame(packet) + # Strip the Dot11 head off the packet and + # assert that what is left is the same payload + payload_expected = packet[-num_bits+self.packet_base_num_bits:] + assert_array_equal(payload_expected, payload) + + # Test a zero-sized payload + num_bits = self.packet_base_num_bits + packet = self.fg.pack_frame(self.bs, num_bits) + payload = self.fg.unpack_frame(packet) + self.assertEqual(packet.size, self.packet_base_num_bits) + assert_array_equal(payload, []) + + # Test invalid num_bits + for num_bits in [-1, self.packet_base_num_bits-1]: + with self.assertRaises(ValueError): + self.fg.pack_frame(self.bs, num_bits) + + # Test unpacking of an invalid frame + with self.assertRaises(ValueError): + # Assuming this is a packet from the zero-sized payload test + self.fg.unpack_frame(packet[1:]) diff --git a/tests/unit_tests/simulation/test_animation.py b/tests/unit_tests/simulation/test_animation.py index 507271a9..24812dce 100644 --- a/tests/unit_tests/simulation/test_animation.py +++ b/tests/unit_tests/simulation/test_animation.py @@ -3,10 +3,16 @@ from unittest import TestCase import numpy as np -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, assert_array_almost_equal from hermespy.core import Transformation -from hermespy.simulation.animation import BITrajectoryB, LinearTrajectory, Moveable, StaticTrajectory, TrajectorySample +from hermespy.simulation.animation import ( + BITrajectoryB, + LinearTrajectory, + Moveable, + StaticTrajectory, + TrajectorySample, +) __author__ = "Jan Adler" __copyright__ = "Copyright 2024, Barkhausen Institut gGmbH" @@ -20,20 +26,20 @@ class TestTrajectorySample(TestCase): """Test trajectory sample data class.""" - + def setUp(self) -> None: - + self.timestamp = 12345 self.pose = Transformation.From_RPY(np.array([1, 2, 3]), np.array([4, 5, 6])) self.velocity = np.array([7, 8, 9]) - + self.sample = TrajectorySample(self.timestamp, self.pose, self.velocity) - + def test_init(self) -> None: """Initialization parameters are stored correctly.""" - + trajectory_sample = TrajectorySample(self.timestamp, self.pose, self.velocity) - + self.assertEqual(trajectory_sample.timestamp, self.timestamp) assert_array_equal(trajectory_sample.velocity, self.velocity) assert_array_equal(trajectory_sample.pose, self.pose) @@ -41,81 +47,122 @@ def test_init(self) -> None: class TestLinearTrajectory(TestCase): """Test the linear trajectory class.""" - + def setUp(self) -> None: - + self.initial_pose = Transformation.From_RPY(np.array([1, 2, 3]), np.array([4, 5, 6])) self.final_pose = Transformation.From_RPY(np.array([7, 8, 9]), np.array([10, 11, 12])) self.duration = 9.876 self.start = 1.234 - - self.linear_trajectory = LinearTrajectory(self.initial_pose, self.final_pose, self.duration, self.start) + + self.linear_trajectory = LinearTrajectory( + self.initial_pose, self.final_pose, self.duration, self.start + ) def test_init_validation(self) -> None: """Initialization should raise an error if the duration is negative.""" - + with self.assertRaises(ValueError): LinearTrajectory(self.initial_pose, self.final_pose, -1, self.start) - + with self.assertRaises(ValueError): LinearTrajectory(self.initial_pose, self.final_pose, 1, -1) def test_max_timestamp(self) -> None: """Max timestamp should be the sum of start and duration.""" - + self.assertAlmostEqual(self.linear_trajectory.max_timestamp, self.start + self.duration) def test_sample_before_start(self) -> None: """Sampling before start should return the initial pose.""" - + sample = self.linear_trajectory.sample(self.start - 1) - assert_array_equal(self.initial_pose, sample.pose) - + assert_array_almost_equal(self.initial_pose, sample.pose) + def test_sample_after_end(self) -> None: """Sampling after end should return the final pose.""" - + sample = self.linear_trajectory.sample(self.start + self.duration + 1) - assert_array_equal(self.final_pose, sample.pose) - + assert_array_almost_equal(self.final_pose, sample.pose) + def test_sample(self) -> None: """Sampling within the trajectory should return the correct pose.""" - + sample = self.linear_trajectory.sample(self.start + self.duration / 2) - assert_array_equal(sample.pose.translation, (self.initial_pose.translation + self.final_pose.translation) / 2) + assert_array_equal( + sample.pose.translation, + (self.initial_pose.translation + self.final_pose.translation) / 2, + ) + + def test_lookat(self) -> None: + """Setting lookat target, flag and up should correctly override samples' rotation.""" + + # Try enabling lookup without a target + with self.assertRaises(RuntimeError): + self.linear_trajectory.lookat_enable() + + # Enable lookup + target_position = np.array([2., 5., 0.]) + target_trajectory = StaticTrajectory(Transformation.From_Translation(target_position)) + up = np.array([0., 1., 0.]) + self.linear_trajectory.lookat(target_trajectory, up) + + for t in np.linspace(self.start, self.start+self.duration, 10, True): + pose = self.linear_trajectory.sample(t).pose + # Assert local Z pointing towards the target + f = target_position - pose[:3, 3] + assert_array_almost_equal(pose[:3, 2], f / np.linalg.norm(f)) + # Assert correct up alignment + s = np.cross(up, pose[:3, 2]) + assert_array_almost_equal(pose[:3, 0], s / np.linalg.norm(s)) + # Try disabling lookat + self.linear_trajectory.lookat_disable() + self.assertTrue(np.any(pose != self.linear_trajectory.sample(t).pose)) + # Try enabling it back + self.linear_trajectory.lookat_enable() + assert_array_equal(pose, self.linear_trajectory.sample(t).pose) class TestStaticTrajectory(TestCase): """Test static trajectory class.""" - + def setUp(self) -> None: self.pose = Transformation.From_RPY(np.array([1, 2, 3]), np.array([4, 5, 6])) self.velocity = np.array([7, 8, 9]) - + self.trajectory = StaticTrajectory(self.pose, self.velocity) def test_init(self) -> None: """Initialization parameters are stored correctly.""" - + assert_array_equal(self.trajectory.pose, self.pose) assert_array_equal(self.trajectory.velocity, self.velocity) - + def test_sample(self) -> None: """Sampling should return the correct pose.""" - - sample = self.trajectory.sample(12345) + + t = 12345 + + sample = self.trajectory.sample(t) assert_array_equal(sample.pose, self.pose) assert_array_equal(sample.velocity, self.velocity) + assert_array_equal(self.trajectory.sample_velocity(t), self.velocity) + assert_array_equal(self.trajectory.sample_translation(t), self.pose.translation) + assert_array_equal(self.trajectory.sample_orientation(t), self.pose[:3, :3]) + class TestMoveable(TestCase): """Test moveable base class""" def setUp(self) -> None: - + self.trajectory = LinearTrajectory( Transformation.From_Translation(np.array([1, 2, 3])), Transformation.From_Translation(np.array([8, 2, 3])), - 10, 2) + 10, + 2, + ) self.moveable = Moveable(self.trajectory) def test_init(self) -> None: @@ -125,40 +172,65 @@ def test_init(self) -> None: def test_trajectory_setget(self) -> None: """Trajectory property getter should return setter argument""" - + expected_trajectory = LinearTrajectory( Transformation.From_Translation(np.array([1, 2, 5])), Transformation.From_Translation(np.array([8, 2, 5])), - 10, 2) - + 10, + 2, + ) + self.moveable.trajectory = expected_trajectory self.assertIs(self.moveable.trajectory, expected_trajectory) - + def test_max_trajectory_timestamp(self) -> None: """Max timestamp should be the trajectory's max timestamp""" - - self.assertAlmostEqual(self.moveable.max_trajectory_timestamp, self.trajectory.max_timestamp) + + self.assertAlmostEqual( + self.moveable.max_trajectory_timestamp, self.trajectory.max_timestamp + ) class TestBITrajectoryB(TestCase): """Test the BI trajectory class.""" - + def setUp(self) -> None: - + self.height = 10 self.duration = 11.2345 - + self.trajectory = BITrajectoryB(self.height, self.duration) def test_max_timestamp(self) -> None: """Max timestamp should be the duration""" self.assertEqual(self.duration, self.trajectory.max_timestamp) - + def test_sample(self) -> None: """Sample should return valid poses for each leg""" - - leg_timestamps = [1.5 * self.duration, 4.5 * self.duration / 5, 3.5 * self.duration / 5, 2.5 * self.duration / 5, 1.0 * self.duration / 5] - for leg_timetamp in leg_timestamps: + + lookat_target = StaticTrajectory(Transformation.From_Translation(np.array([1., 2., 3.]))) + self.trajectory.lookat(lookat_target) + leg_timetamps = [ + 1.5 * self.duration, + 4.5 * self.duration / 5, + 3.5 * self.duration / 5, + 2.5 * self.duration / 5, + 1.0 * self.duration / 5, + ] + default_orientation = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) + for leg_timetamp in leg_timetamps: sample = self.trajectory.sample(leg_timetamp) self.assertEqual(sample.timestamp, leg_timetamp) + # sample_velocity + assert_array_equal(sample.velocity, + self.trajectory.sample_velocity(leg_timetamp)) + # sample_translation + assert_array_equal(sample.pose.translation, + self.trajectory.sample_translation(leg_timetamp)) + # sample_orientation + assert_array_equal(default_orientation, + self.trajectory.sample_orientation(leg_timetamp)) + # assert lookat orientation + f = lookat_target.pose[:3, 3] - sample.pose[:3, 3] + assert_array_almost_equal(sample.pose[:3, 2], f / np.linalg.norm(f)) diff --git a/tests/unit_tests/simulation/test_drop.py b/tests/unit_tests/simulation/test_drop.py index f11be797..ca88b544 100644 --- a/tests/unit_tests/simulation/test_drop.py +++ b/tests/unit_tests/simulation/test_drop.py @@ -32,24 +32,6 @@ def test_channel_realizations(self) -> None: self.assertEqual(1, len(self.drop.channel_realizations)) - def test_hdf_serialization_validation(self) -> None: - """HDF serialization should raise ValueError on invalid scenario arguments""" - - file = File("test.h5", "w", driver="core", backing_store=False) - group = file.create_group("group") - - self.drop.to_HDF(group) - - with self.assertRaises(ValueError): - _ = self.drop.from_HDF(group) - - self.scenario.new_device() - - with self.assertRaises(ValueError): - _ = SimulatedDrop.from_HDF(group, scenario=self.scenario) - - file.close() - def test_hdf_serialization(self) -> None: """Test HDF roundtrip serialization""" @@ -57,7 +39,7 @@ def test_hdf_serialization(self) -> None: group = file.create_group("group") self.drop.to_HDF(group) - deserialization = SimulatedDrop.from_HDF(group, scenario=self.scenario) + deserialization = SimulatedDrop.from_HDF(group, self.scenario.devices, self.scenario.channels) file.close()