diff --git a/test/common.py b/test/common.py index 76ede88f4..51c09874d 100644 --- a/test/common.py +++ b/test/common.py @@ -12,7 +12,7 @@ from transactron.core import SignalBundle, Method, TransactionModule from transactron.lib import AdapterBase, AdapterTrans -from transactron._utils import mock_def_helper +from transactron._utils import def_helper from coreblocks.utils import ValueLike, HasElaborate, HasDebugSignals, auto_debug_signals, LayoutLike, ModuleConnector from .gtkw_extension import write_vcd_ext @@ -260,6 +260,10 @@ def random_wait(self, max_cycle_cnt): yield from self.tick(random.randrange(max_cycle_cnt)) +def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T: + return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg) + + class TestbenchIO(Elaboratable): def __init__(self, adapter: AdapterBase): self.adapter = adapter diff --git a/test/transactions/test_transaction_lib.py b/test/transactions/test_transaction_lib.py index a23ee077e..aa395ecb9 100644 --- a/test/transactions/test_transaction_lib.py +++ b/test/transactions/test_transaction_lib.py @@ -13,7 +13,7 @@ from transactron.core import RecordDict from transactron.lib import * from coreblocks.utils import * -from coreblocks.utils._typing import LayoutLike, ModuleLike +from coreblocks.utils._typing import LayoutLike from coreblocks.utils import ModuleConnector from ..common import ( SimpleTestCircuit, @@ -362,21 +362,21 @@ def elaborate(self, platform): layout = data_layout(self.iosize) - def itransform_rec(m: ModuleLike, v: Record) -> Record: - s = Record.like(v) - m.d.comb += s.data.eq(v.data + 1) + def itransform_rec(m: TModule, arg: Record) -> Record: + s = Record.like(arg) + m.d.comb += s.data.eq(arg.data + 1) return s - def otransform_rec(m: ModuleLike, v: Record) -> Record: - s = Record.like(v) - m.d.comb += s.data.eq(v.data - 1) + def otransform_rec(m: TModule, arg: Record) -> Record: + s = Record.like(arg) + m.d.comb += s.data.eq(arg.data - 1) return s - def itransform_dict(_, v: Record) -> RecordDict: - return {"data": v.data + 1} + def itransform_dict(data: Value) -> RecordDict: + return {"data": data + 1} - def otransform_dict(_, v: Record) -> RecordDict: - return {"data": v.data - 1} + def otransform_dict(data: Value) -> RecordDict: + return {"data": data - 1} if self.use_dicts: itransform = itransform_dict @@ -388,16 +388,13 @@ def otransform_dict(_, v: Record) -> RecordDict: m.submodules.target = self.target = TestbenchIO(Adapter(i=layout, o=layout)) if self.use_methods: + assert self.use_dicts + imeth = Method(i=layout, o=layout) ometh = Method(i=layout, o=layout) - @def_method(m, imeth) - def _(arg: Record): - return itransform(m, arg) - - @def_method(m, ometh) - def _(arg: Record): - return otransform(m, arg) + def_method(m, imeth)(itransform) + def_method(m, ometh)(otransform) trans = MethodTransformer( self.target.adapter.iface, i_transform=(layout, imeth), o_transform=(layout, ometh) @@ -483,8 +480,8 @@ def test_method_filter_with_methods(self): def test_method_filter(self): self.initialize() - def condition(_, v): - return v[0] + def condition(data: Value): + return data[0] self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, condition)) m = ModuleConnector(test_circuit=self.tc, target=self.target) @@ -515,7 +512,7 @@ def elaborate(self, platform): combiner = None if self.add_combiner: - combiner = (layout, lambda _, vs: {"data": sum(vs)}) + combiner = (layout, lambda vs: {"data": sum(vs)}) m.submodules.product = product = MethodProduct(methods, combiner) @@ -702,7 +699,7 @@ def elaborate(self, platform): combiner = None if self.add_combiner: - combiner = (layout, lambda _, vs: {"data": sum(Mux(s, r, 0) for (s, r) in vs)}) + combiner = (layout, lambda vs: {"data": sum(Mux(s, r, 0) for (s, r) in vs)}) m.submodules.product = product = MethodTryProduct(methods, combiner) diff --git a/transactron/_utils.py b/transactron/_utils.py index 138c7222b..84746be3b 100644 --- a/transactron/_utils.py +++ b/transactron/_utils.py @@ -1,7 +1,8 @@ import itertools import sys +import functools from inspect import Parameter, signature -from typing import Any, Concatenate, Optional, TypeAlias, TypeGuard, TypeVar +from typing import Concatenate, Optional, ParamSpec, TypeAlias, TypeGuard, TypeVar from collections.abc import Callable, Iterable, Mapping from amaranth import * from coreblocks.utils._typing import LayoutLike @@ -16,12 +17,14 @@ "GraphCC", "get_caller_class_name", "def_helper", - "method_def_helper", + "bind_first_param", ] T = TypeVar("T") U = TypeVar("U") +P = ParamSpec("P") +CallableOptParam: TypeAlias = Callable[Concatenate[U, P], T] | Callable[P, T] class Scheduler(Elaboratable): @@ -124,7 +127,9 @@ def _graph_ccs(gr: ROGraph[T]) -> list[GraphCC[T]]: MethodLayout: TypeAlias = LayoutLike -def has_first_param(func: Callable[..., T], name: str, tp: type[U]) -> TypeGuard[Callable[Concatenate[U, ...], T]]: +def has_first_param( + func: CallableOptParam[U, P, T], name: str, tp: type[U] +) -> TypeGuard[Callable[Concatenate[U, P], T]]: parameters = signature(func).parameters return ( len(parameters) >= 1 @@ -134,6 +139,13 @@ def has_first_param(func: Callable[..., T], name: str, tp: type[U]) -> TypeGuard ) +def bind_first_param(func: CallableOptParam[U, P, T], name: str, tp: type[U], arg: U) -> Callable[P, T]: + if has_first_param(func, name, tp): + return functools.partial(func, arg) # type: ignore + else: + return func # type: ignore + + def def_helper(description, func: Callable[..., T], tp: type[U], arg: U, /, **kwargs) -> T: parameters = signature(func).parameters kw_parameters = set( @@ -147,14 +159,6 @@ def def_helper(description, func: Callable[..., T], tp: type[U], arg: U, /, **kw raise TypeError(f"Invalid {description}: {func}") -def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T: - return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg) - - -def method_def_helper(method, func: Callable[..., T], arg: Record) -> T: - return def_helper(f"method definition for {method}", func, Record, arg, **arg.fields) - - def get_caller_class_name(default: Optional[str] = None) -> tuple[Optional[Elaboratable], str]: caller_frame = sys._getframe(2) if "self" in caller_frame.f_locals: diff --git a/transactron/core.py b/transactron/core.py index 035b3b588..6f380b90c 100644 --- a/transactron/core.py +++ b/transactron/core.py @@ -43,6 +43,7 @@ ] +T = TypeVar("T") TransactionGraph: TypeAlias = Graph["Transaction"] TransactionGraphCC: TypeAlias = GraphCC["Transaction"] PriorityOrder: TypeAlias = dict["Transaction", int] @@ -1162,6 +1163,10 @@ def debug_signals(self) -> SignalBundle: return [self.ready, self.run, self.data_in, self.data_out] +def method_def_helper(method: Method, func: Callable[..., T], arg: Record) -> T: + return def_helper(f"method definition for {method}", func, Record, arg, **arg.fields) + + def def_method(m: TModule, method: Method, ready: ValueLike = C(1)): """Define a method. diff --git a/transactron/lib/transformers.py b/transactron/lib/transformers.py index e4b7aa0c0..8bcad5bf6 100644 --- a/transactron/lib/transformers.py +++ b/transactron/lib/transformers.py @@ -1,9 +1,10 @@ from amaranth import * from ..core import * from ..core import RecordDict -from typing import Optional +from typing import Optional, ParamSpec, TypeAlias, TypeVar from collections.abc import Callable from coreblocks.utils import ValueLike, assign, AssignType +from transactron._utils import def_helper, bind_first_param, CallableOptParam from .connectors import Forwarder, ManyToOneConnectTrans, ConnectTrans __all__ = [ @@ -17,6 +18,19 @@ ] +T = TypeVar("T") +P = ParamSpec("P") +CallableOptTModule: TypeAlias = CallableOptParam[TModule, P, T] + + +def bind_tmodule(m: TModule, func: CallableOptTModule[P, T]) -> Callable[P, T]: + return bind_first_param(func, "m", TModule, m) + + +def transformer_helper(tr, m: TModule, func: Callable[..., T], arg: Record) -> T: + return def_helper(f"function for {tr}", bind_tmodule(m, func), Record, arg, **arg.fields) + + class MethodTransformer(Elaboratable): """Method transformer. @@ -36,8 +50,8 @@ def __init__( self, target: Method, *, - i_transform: Optional[tuple[MethodLayout, Callable[[TModule, Record], RecordDict]]] = None, - o_transform: Optional[tuple[MethodLayout, Callable[[TModule, Record], RecordDict]]] = None, + i_transform: Optional[tuple[MethodLayout, Callable[..., RecordDict]]] = None, + o_transform: Optional[tuple[MethodLayout, Callable[..., RecordDict]]] = None, ): """ Parameters @@ -54,9 +68,9 @@ def __init__( If not present, output is not transformed. """ if i_transform is None: - i_transform = (target.data_in.layout, lambda _, x: x) + i_transform = (target.data_in.layout, lambda arg: arg) if o_transform is None: - o_transform = (target.data_out.layout, lambda _, x: x) + o_transform = (target.data_out.layout, lambda arg: arg) self.target = target self.method = Method(i=i_transform[0], o=o_transform[0]) @@ -68,7 +82,7 @@ def elaborate(self, platform): @def_method(m, self.method) def _(arg): - return self.o_fun(m, self.target(m, self.i_fun(m, arg))) + return transformer_helper(self, m, self.o_fun, self.target(m, transformer_helper(self, m, self.i_fun, arg))) return m @@ -91,9 +105,7 @@ class MethodFilter(Elaboratable): The transformed method. """ - def __init__( - self, target: Method, condition: Callable[[TModule, Record], ValueLike], default: Optional[RecordDict] = None - ): + def __init__(self, target: Method, condition: Callable[..., ValueLike], default: Optional[RecordDict] = None): """ Parameters ---------- @@ -122,7 +134,7 @@ def elaborate(self, platform): @def_method(m, self.method) def _(arg): - with m.If(self.condition(m, arg)): + with m.If(transformer_helper(self, m, self.condition, arg)): m.d.comb += ret.eq(self.target(m, arg)) return ret @@ -133,7 +145,7 @@ class MethodProduct(Elaboratable): def __init__( self, targets: list[Method], - combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[Record]], RecordDict]]] = None, + combiner: Optional[tuple[MethodLayout, CallableOptTModule[[list[Record]], RecordDict]]] = None, ): """Method product. @@ -148,10 +160,11 @@ def __init__( ---------- targets: list[Method] A list of methods to be called. - combiner: (int or method layout, function), optional + combiner: (record layout, function), optional A pair of the output layout and the combiner function. The - combiner function takes two parameters: a `Module` and - a list of outputs of the target methods. + combiner function takes a list of outputs of the target methods + and returns the result. Optionally, it can also take a `TModule` + as a first argument named `m`. Attributes ---------- @@ -159,7 +172,7 @@ def __init__( The product method. """ if combiner is None: - combiner = (targets[0].data_out.layout, lambda _, x: x[0]) + combiner = (targets[0].data_out.layout, lambda x: x[0]) self.targets = targets self.combiner = combiner self.method = Method(i=targets[0].data_in.layout, o=combiner[0]) @@ -172,7 +185,7 @@ def _(arg): results = [] for target in self.targets: results.append(target(m, arg)) - return self.combiner[1](m, results) + return bind_tmodule(m, self.combiner[1])(results) return m @@ -181,7 +194,7 @@ class MethodTryProduct(Elaboratable): def __init__( self, targets: list[Method], - combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[tuple[Value, Record]]], RecordDict]]] = None, + combiner: Optional[tuple[MethodLayout, CallableOptTModule[[list[tuple[Value, Record]]], RecordDict]]] = None, ): """Method product with optional calling. @@ -196,11 +209,12 @@ def __init__( ---------- targets: list[Method] A list of methods to be called. - combiner: (int or method layout, function), optional + combiner: (record layout, function), optional A pair of the output layout and the combiner function. The - combiner function takes two parameters: a `Module` and - a list of pairs. Each pair contains a bit which signals - that a given call succeeded, and the result of the call. + combiner function takes a list of pairs such that the first + element is a bit which signals that a given call succeeded, + and the second is the result of the call. Optionally, it can + also take a `TModule` as a first argument named `m`. Attributes ---------- @@ -208,7 +222,7 @@ def __init__( The product method. """ if combiner is None: - combiner = ([], lambda _, __: {}) + combiner = ([], lambda arg: {}) self.targets = targets self.combiner = combiner self.method = Method(i=targets[0].data_in.layout, o=combiner[0]) @@ -224,7 +238,7 @@ def _(arg): with Transaction().body(m): m.d.comb += success.eq(1) results.append((success, target(m, arg))) - return self.combiner[1](m, results) + return bind_tmodule(m, self.combiner[1])(results) return m @@ -323,8 +337,8 @@ def __init__( method1: Method, method2: Method, *, - i_fun: Optional[Callable[[TModule, Record], RecordDict]] = None, - o_fun: Optional[Callable[[TModule, Record], RecordDict]] = None, + i_fun: Optional[Callable[..., RecordDict]] = None, + o_fun: Optional[Callable[..., RecordDict]] = None, ): """ Parameters @@ -340,8 +354,8 @@ def __init__( """ self.method1 = method1 self.method2 = method2 - self.i_fun = i_fun or (lambda _, x: x) - self.o_fun = o_fun or (lambda _, x: x) + self.i_fun = i_fun or (lambda arg: arg) + self.o_fun = o_fun or (lambda arg: arg) def elaborate(self, platform): m = TModule()