Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better transformer API #489

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from transactron.core import SignalBundle, Method, TransactionModule
from transactron.lib import AdapterBase, AdapterTrans
from transactron._utils import mock_def_helper
from transactron._utils import def_helper
from coreblocks.utils import ValueLike, HasElaborate, HasDebugSignals, auto_debug_signals, LayoutLike, ModuleConnector
from .gtkw_extension import write_vcd_ext

Expand Down Expand Up @@ -260,6 +260,10 @@ def random_wait(self, max_cycle_cnt):
yield from self.tick(random.randrange(max_cycle_cnt))


def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T:
return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg)


class TestbenchIO(Elaboratable):
def __init__(self, adapter: AdapterBase):
self.adapter = adapter
Expand Down
41 changes: 19 additions & 22 deletions test/transactions/test_transaction_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from transactron.core import RecordDict
from transactron.lib import *
from coreblocks.utils import *
from coreblocks.utils._typing import LayoutLike, ModuleLike
from coreblocks.utils._typing import LayoutLike
from coreblocks.utils import ModuleConnector
from ..common import (
SimpleTestCircuit,
Expand Down Expand Up @@ -362,21 +362,21 @@ def elaborate(self, platform):

layout = data_layout(self.iosize)

def itransform_rec(m: ModuleLike, v: Record) -> Record:
s = Record.like(v)
m.d.comb += s.data.eq(v.data + 1)
def itransform_rec(m: TModule, arg: Record) -> Record:
s = Record.like(arg)
m.d.comb += s.data.eq(arg.data + 1)
return s

def otransform_rec(m: ModuleLike, v: Record) -> Record:
s = Record.like(v)
m.d.comb += s.data.eq(v.data - 1)
def otransform_rec(m: TModule, arg: Record) -> Record:
s = Record.like(arg)
m.d.comb += s.data.eq(arg.data - 1)
return s

def itransform_dict(_, v: Record) -> RecordDict:
return {"data": v.data + 1}
def itransform_dict(data: Value) -> RecordDict:
return {"data": data + 1}

def otransform_dict(_, v: Record) -> RecordDict:
return {"data": v.data - 1}
def otransform_dict(data: Value) -> RecordDict:
return {"data": data - 1}

if self.use_dicts:
itransform = itransform_dict
Expand All @@ -388,16 +388,13 @@ def otransform_dict(_, v: Record) -> RecordDict:
m.submodules.target = self.target = TestbenchIO(Adapter(i=layout, o=layout))

if self.use_methods:
assert self.use_dicts

imeth = Method(i=layout, o=layout)
ometh = Method(i=layout, o=layout)

@def_method(m, imeth)
def _(arg: Record):
return itransform(m, arg)

@def_method(m, ometh)
def _(arg: Record):
return otransform(m, arg)
def_method(m, imeth)(itransform)
def_method(m, ometh)(otransform)

trans = MethodTransformer(
self.target.adapter.iface, i_transform=(layout, imeth), o_transform=(layout, ometh)
Expand Down Expand Up @@ -483,8 +480,8 @@ def test_method_filter_with_methods(self):
def test_method_filter(self):
self.initialize()

def condition(_, v):
return v[0]
def condition(data: Value):
return data[0]

self.tc = SimpleTestCircuit(MethodFilter(self.target.adapter.iface, condition))
m = ModuleConnector(test_circuit=self.tc, target=self.target)
Expand Down Expand Up @@ -515,7 +512,7 @@ def elaborate(self, platform):

combiner = None
if self.add_combiner:
combiner = (layout, lambda _, vs: {"data": sum(vs)})
combiner = (layout, lambda vs: {"data": sum(vs)})

m.submodules.product = product = MethodProduct(methods, combiner)

Expand Down Expand Up @@ -702,7 +699,7 @@ def elaborate(self, platform):

combiner = None
if self.add_combiner:
combiner = (layout, lambda _, vs: {"data": sum(Mux(s, r, 0) for (s, r) in vs)})
combiner = (layout, lambda vs: {"data": sum(Mux(s, r, 0) for (s, r) in vs)})

m.submodules.product = product = MethodTryProduct(methods, combiner)

Expand Down
26 changes: 15 additions & 11 deletions transactron/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import itertools
import sys
import functools
from inspect import Parameter, signature
from typing import Any, Concatenate, Optional, TypeAlias, TypeGuard, TypeVar
from typing import Concatenate, Optional, ParamSpec, TypeAlias, TypeGuard, TypeVar
from collections.abc import Callable, Iterable, Mapping
from amaranth import *
from coreblocks.utils._typing import LayoutLike
Expand All @@ -16,12 +17,14 @@
"GraphCC",
"get_caller_class_name",
"def_helper",
"method_def_helper",
"bind_first_param",
]


T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")
CallableOptParam: TypeAlias = Callable[Concatenate[U, P], T] | Callable[P, T]


class Scheduler(Elaboratable):
Expand Down Expand Up @@ -124,7 +127,9 @@ def _graph_ccs(gr: ROGraph[T]) -> list[GraphCC[T]]:
MethodLayout: TypeAlias = LayoutLike


def has_first_param(func: Callable[..., T], name: str, tp: type[U]) -> TypeGuard[Callable[Concatenate[U, ...], T]]:
def has_first_param(
func: CallableOptParam[U, P, T], name: str, tp: type[U]
) -> TypeGuard[Callable[Concatenate[U, P], T]]:
parameters = signature(func).parameters
return (
len(parameters) >= 1
Expand All @@ -134,6 +139,13 @@ def has_first_param(func: Callable[..., T], name: str, tp: type[U]) -> TypeGuard
)


def bind_first_param(func: CallableOptParam[U, P, T], name: str, tp: type[U], arg: U) -> Callable[P, T]:
if has_first_param(func, name, tp):
return functools.partial(func, arg) # type: ignore
else:
return func # type: ignore


def def_helper(description, func: Callable[..., T], tp: type[U], arg: U, /, **kwargs) -> T:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't be func annotated as CallableOptParam? I have the same question regarding mock_def_helper and method_def_helper.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could, but it basically changes nothing because of the ... parameter specification.

parameters = signature(func).parameters
kw_parameters = set(
Expand All @@ -147,14 +159,6 @@ def def_helper(description, func: Callable[..., T], tp: type[U], arg: U, /, **kw
raise TypeError(f"Invalid {description}: {func}")


def mock_def_helper(tb, func: Callable[..., T], arg: Mapping[str, Any]) -> T:
return def_helper(f"mock definition for {tb}", func, Mapping[str, Any], arg, **arg)


def method_def_helper(method, func: Callable[..., T], arg: Record) -> T:
return def_helper(f"method definition for {method}", func, Record, arg, **arg.fields)


def get_caller_class_name(default: Optional[str] = None) -> tuple[Optional[Elaboratable], str]:
caller_frame = sys._getframe(2)
if "self" in caller_frame.f_locals:
Expand Down
5 changes: 5 additions & 0 deletions transactron/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
]


T = TypeVar("T")
TransactionGraph: TypeAlias = Graph["Transaction"]
TransactionGraphCC: TypeAlias = GraphCC["Transaction"]
PriorityOrder: TypeAlias = dict["Transaction", int]
Expand Down Expand Up @@ -1162,6 +1163,10 @@ def debug_signals(self) -> SignalBundle:
return [self.ready, self.run, self.data_in, self.data_out]


def method_def_helper(method: Method, func: Callable[..., T], arg: Record) -> T:
return def_helper(f"method definition for {method}", func, Record, arg, **arg.fields)


def def_method(m: TModule, method: Method, ready: ValueLike = C(1)):
"""Define a method.

Expand Down
68 changes: 41 additions & 27 deletions transactron/lib/transformers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from amaranth import *
from ..core import *
from ..core import RecordDict
from typing import Optional
from typing import Optional, ParamSpec, TypeAlias, TypeVar
from collections.abc import Callable
from coreblocks.utils import ValueLike, assign, AssignType
from transactron._utils import def_helper, bind_first_param, CallableOptParam
from .connectors import Forwarder, ManyToOneConnectTrans, ConnectTrans

__all__ = [
Expand All @@ -17,6 +18,19 @@
]


T = TypeVar("T")
P = ParamSpec("P")
CallableOptTModule: TypeAlias = CallableOptParam[TModule, P, T]


def bind_tmodule(m: TModule, func: CallableOptTModule[P, T]) -> Callable[P, T]:
return bind_first_param(func, "m", TModule, m)


def transformer_helper(tr, m: TModule, func: Callable[..., T], arg: Record) -> T:
return def_helper(f"function for {tr}", bind_tmodule(m, func), Record, arg, **arg.fields)


class MethodTransformer(Elaboratable):
"""Method transformer.

Expand All @@ -36,8 +50,8 @@ def __init__(
self,
target: Method,
*,
i_transform: Optional[tuple[MethodLayout, Callable[[TModule, Record], RecordDict]]] = None,
o_transform: Optional[tuple[MethodLayout, Callable[[TModule, Record], RecordDict]]] = None,
i_transform: Optional[tuple[MethodLayout, Callable[..., RecordDict]]] = None,
o_transform: Optional[tuple[MethodLayout, Callable[..., RecordDict]]] = None,
):
"""
Parameters
Expand All @@ -54,9 +68,9 @@ def __init__(
If not present, output is not transformed.
"""
if i_transform is None:
i_transform = (target.data_in.layout, lambda _, x: x)
i_transform = (target.data_in.layout, lambda arg: arg)
if o_transform is None:
o_transform = (target.data_out.layout, lambda _, x: x)
o_transform = (target.data_out.layout, lambda arg: arg)

self.target = target
self.method = Method(i=i_transform[0], o=o_transform[0])
Expand All @@ -68,7 +82,7 @@ def elaborate(self, platform):

@def_method(m, self.method)
def _(arg):
return self.o_fun(m, self.target(m, self.i_fun(m, arg)))
return transformer_helper(self, m, self.o_fun, self.target(m, transformer_helper(self, m, self.i_fun, arg)))

return m

Expand All @@ -91,9 +105,7 @@ class MethodFilter(Elaboratable):
The transformed method.
"""

def __init__(
self, target: Method, condition: Callable[[TModule, Record], ValueLike], default: Optional[RecordDict] = None
):
def __init__(self, target: Method, condition: Callable[..., ValueLike], default: Optional[RecordDict] = None):
"""
Parameters
----------
Expand Down Expand Up @@ -122,7 +134,7 @@ def elaborate(self, platform):

@def_method(m, self.method)
def _(arg):
with m.If(self.condition(m, arg)):
with m.If(transformer_helper(self, m, self.condition, arg)):
m.d.comb += ret.eq(self.target(m, arg))
return ret

Expand All @@ -133,7 +145,7 @@ class MethodProduct(Elaboratable):
def __init__(
self,
targets: list[Method],
combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[Record]], RecordDict]]] = None,
combiner: Optional[tuple[MethodLayout, CallableOptTModule[[list[Record]], RecordDict]]] = None,
):
"""Method product.

Expand All @@ -148,18 +160,19 @@ def __init__(
----------
targets: list[Method]
A list of methods to be called.
combiner: (int or method layout, function), optional
combiner: (record layout, function), optional
A pair of the output layout and the combiner function. The
combiner function takes two parameters: a `Module` and
a list of outputs of the target methods.
combiner function takes a list of outputs of the target methods
and returns the result. Optionally, it can also take a `TModule`
as a first argument named `m`.

Attributes
----------
method: Method
The product method.
"""
if combiner is None:
combiner = (targets[0].data_out.layout, lambda _, x: x[0])
combiner = (targets[0].data_out.layout, lambda x: x[0])
self.targets = targets
self.combiner = combiner
self.method = Method(i=targets[0].data_in.layout, o=combiner[0])
Expand All @@ -172,7 +185,7 @@ def _(arg):
results = []
for target in self.targets:
results.append(target(m, arg))
return self.combiner[1](m, results)
return bind_tmodule(m, self.combiner[1])(results)

return m

Expand All @@ -181,7 +194,7 @@ class MethodTryProduct(Elaboratable):
def __init__(
self,
targets: list[Method],
combiner: Optional[tuple[MethodLayout, Callable[[TModule, list[tuple[Value, Record]]], RecordDict]]] = None,
combiner: Optional[tuple[MethodLayout, CallableOptTModule[[list[tuple[Value, Record]]], RecordDict]]] = None,
):
"""Method product with optional calling.

Expand All @@ -196,19 +209,20 @@ def __init__(
----------
targets: list[Method]
A list of methods to be called.
combiner: (int or method layout, function), optional
combiner: (record layout, function), optional
A pair of the output layout and the combiner function. The
combiner function takes two parameters: a `Module` and
a list of pairs. Each pair contains a bit which signals
that a given call succeeded, and the result of the call.
combiner function takes a list of pairs such that the first
element is a bit which signals that a given call succeeded,
and the second is the result of the call. Optionally, it can
also take a `TModule` as a first argument named `m`.

Attributes
----------
method: Method
The product method.
"""
if combiner is None:
combiner = ([], lambda _, __: {})
combiner = ([], lambda arg: {})
self.targets = targets
self.combiner = combiner
self.method = Method(i=targets[0].data_in.layout, o=combiner[0])
Expand All @@ -224,7 +238,7 @@ def _(arg):
with Transaction().body(m):
m.d.comb += success.eq(1)
results.append((success, target(m, arg)))
return self.combiner[1](m, results)
return bind_tmodule(m, self.combiner[1])(results)

return m

Expand Down Expand Up @@ -323,8 +337,8 @@ def __init__(
method1: Method,
method2: Method,
*,
i_fun: Optional[Callable[[TModule, Record], RecordDict]] = None,
o_fun: Optional[Callable[[TModule, Record], RecordDict]] = None,
i_fun: Optional[Callable[..., RecordDict]] = None,
o_fun: Optional[Callable[..., RecordDict]] = None,
):
"""
Parameters
Expand All @@ -340,8 +354,8 @@ def __init__(
"""
self.method1 = method1
self.method2 = method2
self.i_fun = i_fun or (lambda _, x: x)
self.o_fun = o_fun or (lambda _, x: x)
self.i_fun = i_fun or (lambda arg: arg)
self.o_fun = o_fun or (lambda arg: arg)

def elaborate(self, platform):
m = TModule()
Expand Down