diff --git a/stubs/amaranth/sim/core.pyi b/stubs/amaranth/sim/core.pyi index 23fef3472..58241c6bb 100644 --- a/stubs/amaranth/sim/core.pyi +++ b/stubs/amaranth/sim/core.pyi @@ -5,6 +5,7 @@ This type stub file was generated by pyright. from .._utils import deprecated from ..hdl.cd import * from ..hdl.ir import * +from .pysim import * __all__ = ["Settle", "Delay", "Tick", "Passive", "Active", "Simulator"] class Command: @@ -48,6 +49,9 @@ class Active(Command): class Simulator: + _fragment : Fragment + _engine : PySimEngine + _clocked : set def __init__(self, fragment, *, engine=...) -> None: ... diff --git a/test/common/_test/test_infrastructure.py b/test/common/_test/test_infrastructure.py index ecf1c84d9..5dea7ba67 100644 --- a/test/common/_test/test_infrastructure.py +++ b/test/common/_test/test_infrastructure.py @@ -29,3 +29,30 @@ def process(self): def test_random(self): with self.run_simulation(self.m, 50) as sim: sim.add_sync_process(self.process) + + +class TestTrueSettle(TestCaseWithSimulator): + def setUp(self): + self.m = SimpleTestCircuit(EmptyCircuit()) + self.test_cycles = 10 + self.flag = False + random.seed(14) + + def true_settle_process(self): + for k in range(self.test_cycles): + yield TrueSettle() + self.assertTrue(self.flag) + self.flag = False + yield + + def flag_process(self): + for k in range(self.test_cycles): + for i in range(random.randrange(0, 5)): + yield Settle() + self.flag = True + yield + + def test_flag(self): + with self.run_simulation(self.m, 50) as sim: + sim.add_sync_process(self.true_settle_process) + sim.add_sync_process(self.flag_process) diff --git a/test/common/infrastructure.py b/test/common/infrastructure.py index 058d5b9ed..4c6106c92 100644 --- a/test/common/infrastructure.py +++ b/test/common/infrastructure.py @@ -2,6 +2,7 @@ import random import unittest import functools +from enum import Enum from contextlib import contextmanager, nullcontext from typing import TypeVar, Generic, Type, TypeGuard, Any, Union, Callable, cast, TypeAlias from abc import ABC @@ -109,14 +110,32 @@ class Now(CoreblocksCommand): pass +class TrueSettle(CoreblocksCommand): + """Wait till all process are waiting for the next cycle or for the TrueSettle""" + + pass + + +class SyncProcessState(Enum): + """State of SyncProcessWrapper.""" + + sleeping = 0 # Wait for the next cycle + running = 1 + ended = 2 + true_settle = 3 # Wait for the TrueSettle + + class SyncProcessWrapper: def __init__(self, f): self.org_process = f self.current_cycle = 0 + self.state = None + self.blocked = None def _wrapping_function(self): response = None org_coroutine = self.org_process() + self.state = SyncProcessState.running try: while True: # call orginal test process and catch data yielded by it in `command` variable @@ -124,14 +143,26 @@ def _wrapping_function(self): # If process wait for new cycle if command is None: self.current_cycle += 1 + self.state = SyncProcessState.sleeping # forward to amaranth yield - elif isinstance(command, Now): - response = self.current_cycle + self.state = SyncProcessState.running + elif isinstance(command, CoreblocksCommand): + if isinstance(command, Now): + response = self.current_cycle + elif isinstance(command, TrueSettle): + self.state = SyncProcessState.true_settle + self.blocked = True + while self.blocked: + yield Settle() + self.state = SyncProcessState.running + else: + raise RuntimeError(f"Not known CoreblocksCommand: {command}") # Pass everything else to amaranth simulator without modifications else: response = yield command except StopIteration: + self.state = SyncProcessState.ended pass @@ -168,15 +199,30 @@ def __init__(self, module: HasElaborate, max_cycles: float = 10e4, add_transacti self.ctx = nullcontext() self.deadline = clk_period * max_cycles + self.sync_proc_list = [] def add_sync_process(self, f: Callable[[], TestGen]): f_wrapped = SyncProcessWrapper(f) + self.sync_proc_list.append(f_wrapped) super().add_sync_process(f_wrapped._wrapping_function) - def run(self) -> bool: - with self.ctx: - self.run_until(self.deadline) + def _check_true_settle_ready(self): + return all(p.state != SyncProcessState.running for p in self.sync_proc_list) + def _unblock_sync_processes(self): + for p in self.sync_proc_list: + if p.blocked: + p.blocked = False + + def run(self) -> bool: + deadline = self.deadline * 1e12 + assert self._engine.now <= deadline + last_now = self._engine.now + while self.advance() and self._engine.now < deadline: + if last_now == self._engine.now: + if self._check_true_settle_ready(): + self._unblock_sync_processes() + last_now = self._engine.now return not self.advance()