Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle transitive calls from a nonexclusive method #719

Merged
merged 5 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion test/transactron/test_methods.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Sequence
from collections.abc import Callable, Sequence
import pytest
import random
from amaranth import *
Expand Down Expand Up @@ -559,6 +559,58 @@ def process():
sim.add_sync_process(process)


class TwoNonexclusiveConflictCircuit(Elaboratable):
def __init__(self, two_nonexclusive: bool):
self.two_nonexclusive = two_nonexclusive

def elaborate(self, platform):
m = TModule()

self.running1 = Signal()
self.running2 = Signal()

method1 = Method(o=data_layout(WIDTH), nonexclusive=True)
method2 = Method(o=data_layout(WIDTH), nonexclusive=self.two_nonexclusive)
method_in = Method(o=data_layout(WIDTH))

@def_method(m, method_in)
def _():
return {"data": 0}

@def_method(m, method1)
def _():
m.d.comb += self.running1.eq(1)
return method_in(m)

@def_method(m, method2)
def _():
m.d.comb += self.running2.eq(1)
return method_in(m)

m.submodules.t1 = self.t1 = TestbenchIO(AdapterTrans(method1))
m.submodules.t2 = self.t2 = TestbenchIO(AdapterTrans(method2))

return m


class TestConflicting(TestCaseWithSimulator):
@pytest.mark.parametrize(
"test_circuit", [lambda: TwoNonexclusiveConflictCircuit(False), lambda: TwoNonexclusiveConflictCircuit(True)]
)
def test_conflicting(self, test_circuit: Callable[[], TwoNonexclusiveConflictCircuit]):
circ = test_circuit()

def process():
yield from circ.t1.enable()
yield from circ.t2.enable()
yield Settle()

assert not (yield circ.running1) or not (yield circ.running2)

with self.run_simulation(circ) as sim:
sim.add_sync_process(process)


class CustomCombinerMethodCircuit(Elaboratable):
def elaborate(self, platform):
m = TModule()
Expand Down
28 changes: 19 additions & 9 deletions transactron/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,25 @@ class MethodMap:
def __init__(self, transactions: Iterable["Transaction"]):
self.methods_by_transaction = dict[Transaction, list[Method]]()
self.transactions_by_method = defaultdict[Method, list[Transaction]](list)
self.readiness_by_method_and_transaction = dict[tuple[Transaction, Method], ValueLike]()
self.readiness_by_call = dict[tuple[Transaction, Method], ValueLike]()
self.ancestors_by_call = dict[tuple[Transaction, Method], tuple[Method, ...]]()
self.method_parents = defaultdict[Method, list[TransactionBase]](list)

def rec(transaction: Transaction, source: TransactionBase):
def rec(transaction: Transaction, source: TransactionBase, ancestors: tuple[Method, ...]):
for method, (arg_rec, _) in source.method_uses.items():
if not method.defined:
raise RuntimeError(f"Trying to use method '{method.name}' which is not defined yet")
if method in self.methods_by_transaction[transaction]:
raise RuntimeError(f"Method '{method.name}' can't be called twice from the same transaction")
self.methods_by_transaction[transaction].append(method)
self.transactions_by_method[method].append(transaction)
self.readiness_by_method_and_transaction[(transaction, method)] = method._validate_arguments(arg_rec)
rec(transaction, method)
self.readiness_by_call[(transaction, method)] = method._validate_arguments(arg_rec)
self.ancestors_by_call[(transaction, method)] = new_ancestors = (method, *ancestors)
rec(transaction, method, new_ancestors)

for transaction in transactions:
self.methods_by_transaction[transaction] = []
rec(transaction, transaction)
rec(transaction, transaction, ())

for transaction_or_method in self.methods_and_transactions:
for method in transaction_or_method.method_uses.keys():
Expand Down Expand Up @@ -127,6 +129,12 @@ def transactions_exclusive(trans1: Transaction, trans2: Transaction):

return False

def calls_nonexclusive(trans1: Transaction, trans2: Transaction, method: Method):
ancestors1 = method_map.ancestors_by_call[(trans1, method)]
ancestors2 = method_map.ancestors_by_call[(trans2, method)]
common_ancestors = longest_common_prefix(ancestors1, ancestors2)
return common_ancestors[-1].nonexclusive

cgr: TransactionGraph = {} # Conflict graph
pgr: TransactionGraph = {} # Priority graph

Expand All @@ -145,11 +153,13 @@ def add_edge(begin: Transaction, end: Transaction, priority: Priority, conflict:
pgr[transaction] = set()

for method in method_map.methods:
if method.nonexclusive:
continue
for transaction1 in method_map.transactions_for(method):
for transaction2 in method_map.transactions_for(method):
if transaction1 is not transaction2 and not transactions_exclusive(transaction1, transaction2):
if (
transaction1 is not transaction2
and not transactions_exclusive(transaction1, transaction2)
and not calls_nonexclusive(transaction1, transaction2, method)
):
add_edge(transaction1, transaction2, Priority.UNDEFINED, True)

relations = [
Expand Down Expand Up @@ -328,7 +338,7 @@ def elaborate(self, platform):

for transaction in self.transactions:
ready = [
method_map.readiness_by_method_and_transaction[transaction, method]
method_map.readiness_by_call[transaction, method]
for method in method_map.methods_by_transaction[transaction]
]
m.d.comb += transaction.runnable.eq(Cat(ready).all())
Expand Down
12 changes: 11 additions & 1 deletion transactron/utils/transactron_helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from contextlib import contextmanager
from typing import Optional, Any, Concatenate, TypeGuard, TypeVar
from collections.abc import Callable, Mapping
from collections.abc import Callable, Mapping, Sequence
from ._typing import ROGraph, GraphCC, SrcLoc, MethodLayout, MethodStruct, ShapeLike, LayoutList, LayoutListField
from inspect import Parameter, signature
from itertools import count
Expand All @@ -11,6 +11,7 @@


__all__ = [
"longest_common_prefix",
"silence_mustuse",
"get_caller_class_name",
"def_helper",
Expand Down Expand Up @@ -60,6 +61,15 @@ def _graph_ccs(gr: ROGraph[T]) -> list[GraphCC[T]]:
return ccs


def longest_common_prefix(*seqs: Sequence[T]) -> Sequence[T]:
if not seqs:
raise ValueError("no arguments")
for i, letter_group in enumerate(zip(*seqs)):
if len(set(letter_group)) > 1:
return seqs[0][:i]
return min(seqs, key=lambda s: len(s))


def has_first_param(func: Callable[..., T], name: str, tp: type[U]) -> TypeGuard[Callable[Concatenate[U, ...], T]]:
parameters = signature(func).parameters
return (
Expand Down