diff --git a/pyecoreocl/compiler.py b/pyecoreocl/compiler.py index 32686ba..f87043a 100644 --- a/pyecoreocl/compiler.py +++ b/pyecoreocl/compiler.py @@ -290,7 +290,7 @@ def call_rule(fun): @call_rule -def rule_collect(emitter, ctx): +def rule_collect_nested(emitter, ctx): emitter.inline("(") emitter.visit(ctx.argExp().oclExp()) variables = [arg.text for arg in ctx.argExp().varnames] @@ -304,6 +304,21 @@ def rule_collect(emitter, ctx): emitter.inline(")") +@call_rule +def rule_collect(emitter, ctx): + emitter.inline("ocl.flatten(") + emitter.visit(ctx.argExp().oclExp()) + variables = [arg.text for arg in ctx.argExp().varnames] + emitter.inline(f" for {', '.join(variables)} in ") + if len(variables) > 1: + emitter.inline("itertools.combinations_with_replacement(") + emitter.visit(ctx.expression) + emitter.inline(f", {len(variables)})") + else: + emitter.visit(ctx.expression) + emitter.inline(")") + + @call_rule def rule_for_all(emitter, ctx): emitter.inline(f"all(") diff --git a/pyecoreocl/runtime.py b/pyecoreocl/runtime.py index d42e2f6..3ec7bd1 100644 --- a/pyecoreocl/runtime.py +++ b/pyecoreocl/runtime.py @@ -11,7 +11,17 @@ def is_unique(lst): seen = [] return not any(i in seen or seen.append(i) for i in lst) + def geta(o, attr): if isinstance(o, Iterable): return (getattr(e, attr) for e in o) - return getattr(o, attr) \ No newline at end of file + return getattr(o, attr) + + +def flatten(items): + for x in items: + if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): + for sub_x in flatten(x): + yield sub_x + else: + yield x diff --git a/tests/test_collection_lib.py b/tests/test_collection_lib.py index d308152..644d6a7 100644 --- a/tests/test_collection_lib.py +++ b/tests/test_collection_lib.py @@ -49,6 +49,23 @@ def test__collect(): l = !Sequence{Tuple{value=3}, Tuple{value=4}, Tuple{value='stuff'}}! assert !l->collect(e | e.value)->asSequence()! == [3, 4, "stuff"] + l = [3, 4, 5] + assert !l->collect(e: int | Sequence{e})->asSequence()! == [3, 4, 5] + + +def test__collect_nested(): + l = [3, 4, 5] + assert !l->collectNested(e: int | e + 1)->asSequence()! == [4, 5, 6] + + l = !Sequence{Tuple{value=3}, Tuple{value=4}, Tuple{value='stuff'}}! + assert !l->collectNested(e | e.value)->asSequence()! == [3, 4, "stuff"] + + l = [3, 4, 5] + assert !l->collectNested(e: int | Sequence{e})->asSequence()! == [[3], [4], [5]] + + l = [3, 4, 5] + assert !l->collectNested(e: int | Sequence{Sequence{e}})->asSequence()! == [[[3]], [[4]], [[5]]] + def test__any(): l = [3, 4, 5]