diff --git a/mathics/core/expression.py b/mathics/core/expression.py index f407f9957..3449677ae 100644 --- a/mathics/core/expression.py +++ b/mathics/core/expression.py @@ -24,13 +24,19 @@ attribute_string_to_number, ) from mathics.core.convert.python import from_python -from mathics.core.element import ElementsProperties, EvalMixin, ensure_context + +# from mathics.core.convert.sympy import SympyExpression, sympy_symbol_prefix +from mathics.core.element import ( + BaseElement, + ElementsProperties, + EvalMixin, + ensure_context, +) from mathics.core.evaluation import Evaluation from mathics.core.interrupt import ReturnInterrupt from mathics.core.structure import LinkedStructure from mathics.core.symbols import ( Atom, - BaseElement, Monomial, NumericOperators, Symbol, @@ -625,31 +631,30 @@ def flatten_with_respect_to_head( if self._does_not_contain_symbol(head.get_name()): return self sub_level = level - 1 - do_flatten = False - for element in self._elements: - if element.get_head().sameQ(head) and ( + indx_to_flatten = [] + for idx, element in enumerate(self._elements): + if ( not pattern_only or element.pattern_sequence - ): - do_flatten = True - break - if do_flatten: - new_elements = [] - for element in self._elements: - if element.get_head().sameQ(head) and ( - not pattern_only or element.pattern_sequence - ): - new_element = element.flatten_with_respect_to_head( - head, pattern_only, callback, level=sub_level - ) - if callback is not None: - callback(new_element._elements, element) - new_elements.extend(new_element._elements) - else: - new_elements.append(element) - return to_expression_with_specialization(self._head, *new_elements) - else: + ) and element.get_head().sameQ(head): + indx_to_flatten.append(idx) + + if len(indx_to_flatten) == 0: return self + new_elements = [] + for idx, element in enumerate(self._elements): + if len(indx_to_flatten) > 0 and idx == indx_to_flatten[0]: + indx_to_flatten.pop(0) + new_element = element.flatten_with_respect_to_head( + head, pattern_only, callback, level=sub_level + ) + if callback is not None: + callback(new_element._elements, element) + new_elements.extend(new_element._elements) + else: + new_elements.append(element) + return to_expression_with_specialization(self._head, *new_elements) + def get_atoms(self, include_heads=True): """Returns a list of atoms involved in the expression.""" # Comment @mmatera: maybe, what we really want here are the Symbol's @@ -1065,36 +1070,99 @@ def rewrite_apply_eval_step(self, evaluation) -> Tuple["Expression", bool]: See also https://mathics-development-guide.readthedocs.io/en/latest/extending/code-overview/evaluation.html#detailed-rewrite-apply-eval-process """ - # Step 1 : evaluate the Head and get its Attributes. These attributes, - # used later, include: HoldFirst / HoldAll / HoldRest / HoldAllComplete. + # Internal class and functions to handle ``Unevaluated`` elements + class UnevaluatedWrapper(BaseElement): + """ + This class is used to wrap the argument of + elements of the form ``Unevaluated[expr_]``, + to provide the right behaviour under + ``flatten_with_respect_to_head``, sort and thread. + The wrapper is removed before step of looking and applying + rules. + If no rule is successfully applied, then the wrapper is converted + again into an expression of the form Unevaluated(expr_) + """ + + def __init__(self, expr: BaseElement): + self.expr = expr + self.elements_properties = ElementsProperties(True, True, True) + + def get_head(self): + return self.expr.get_head() + + def __repr__(self): + return f"<>" + + def get_sort_key(self, pattern_sort=False): + # this ensures that when the elements of an expression are + # sorted, elements tagged follows the corresponding untagged + # elements. + return self.expr.get_sort_key(pattern_sort) + ("Unevaluated",) + + @property + def is_literal(self): + return False + + def flatten_with_respect_to_head( + self, head, pattern_only=False, callback=None, level=100 + ): + flatten_expr = self.expr.flatten_with_respect_to_head( + head=head, pattern_only=pattern_only, callback=callback, level=level + ) + # distribute the tag over the elements. + marked_elements = tuple( + UnevaluatedWrapper(element) for element in flatten_expr._elements + ) + return Expression(self.expr._head, *marked_elements) + + def strip_unevaluated_wrapper(expr_with_wrappers): + items = ( + element.expr if isinstance(element, UnevaluatedWrapper) else element + for element in expr_with_wrappers._elements + ) + return Expression(expr_with_wrappers._head, *items) + + def restore_unevaluated_from_wrapper(expr_with_wrappers): + items = ( + Expression(SymbolUnevaluated, element.expr) + if isinstance(element, UnevaluatedWrapper) + else element + for element in expr_with_wrappers._elements + ) + return Expression(expr_with_wrappers._head, *items) + + # Step 1 : evaluate the Head and get its Attributes. These attributes, used later, include + # HoldFirst / HoldAll / HoldRest / HoldAllComplete. # Note: self._head can be not just a symbol, but some arbitrary expression. # This is what makes expressions in Mathics be M-expressions rather than # S-expressions. head = self._head.evaluate(evaluation) - attributes = head.get_attributes(evaluation.definitions) + contains_unevaluated_wrapper = False if self.elements_properties is None: self._build_elements_properties() # @timeit def eval_elements(): - nonlocal recompute_properties + nonlocal recompute_properties, contains_unevaluated_wrapper # @timeit def eval_range(indices): - nonlocal recompute_properties + nonlocal recompute_properties, contains_unevaluated_wrapper recompute_properties = False for index in indices: element = elements[index] - if not element.has_form("Unevaluated", 1): - if isinstance(element, EvalMixin): - new_value = element.evaluate(evaluation) - # We need id() because != by itself is too permissive - if id(element) != id(new_value): - recompute_properties = True - elements[index] = new_value + if element.has_form("System`Unevaluated", 1): + contains_unevaluated_wrapper = True + elements[index] = UnevaluatedWrapper(element._elements[0]) + elif isinstance(element, EvalMixin): + new_value = element.evaluate(evaluation) + # We need id() because != by itself is too permissive + if id(element) != id(new_value): + recompute_properties = True + elements[index] = new_value # @timeit def rest_range(indices): @@ -1111,6 +1179,9 @@ def rest_range(indices): if id(new_value) != id(element): elements[index] = new_value recompute_properties = True + elif element.has_form("System`Unevaluated", 1): + contains_unevaluated_wrapper = True + elements[index] = UnevaluatedWrapper(element._elements[0]) if (A_HOLD_ALL | A_HOLD_ALL_COMPLETE) & attributes: # eval_range(range(0, 0)) @@ -1161,53 +1232,8 @@ def rest_range(indices): new._build_elements_properties() elements = new._elements - # comment @mmatera: I think this is wrong now, because alters - # singletons... (see PR #58) The idea is to mark which elements was - # marked as "Unevaluated" Also, this consumes time for long lists, and - # is useful just for a very unfrequent expressions, involving - # `Unevaluated` elements. Notice also that this behaviour is broken - # when the argument of "Unevaluated" is a symbol (see comment and tests - # in test/test_unevaluate.py) - - for element in elements: - element.unevaluated = False - - # If HoldAllComplete Attribute (flag ``A_HOLD_ALL_COMPLETE``) is not set, - # and the expression has elements of the form `Unevaluated[element]` - # change them to `element` and set a flag `unevaluated=True` - # If the evaluation fails, use this flag to restore back the initial form - # Unevaluated[element] - - # comment @mmatera: - # what we need here is some way to track which elements are marked as - # Unevaluated, that propagates by flatten, and at the end, - # to recover a list of positions that (eventually) - # must be marked again as Unevaluated. - - if not A_HOLD_ALL_COMPLETE & attributes: - dirty_elements = None - - for index, element in enumerate(elements): - if element.has_form("Unevaluated", 1): - if dirty_elements is None: - dirty_elements = list(elements) - dirty_elements[index] = element._elements[0] - dirty_elements[index].unevaluated = True - - if dirty_elements: - new = Expression(head, *dirty_elements) - elements = dirty_elements - new._build_elements_properties() - - # If the Attribute ``Flat`` (flag ``A_FLAT``) is set, calls - # flatten with a callback that set elements as unevaluated - # too. - def flatten_callback(new_elements, old): - for element in new_elements: - element.unevaluated = old.unevaluated - if A_FLAT & attributes: - new = new.flatten_with_respect_to_head(new._head, callback=flatten_callback) + new = new.flatten_with_respect_to_head(new._head) if new.elements_properties is None: new._build_elements_properties() @@ -1239,12 +1265,21 @@ def flatten_callback(new_elements, old): # threading. Still, we need to perform this rewrite to # maintain correct semantic behavior. if A_LISTABLE & attributes: + # TODO: Check how Unevaluated works here done, threaded = new.thread(evaluation) if done: + # if contains_unevaluated_wrapper: + # new = restore_unevaluated_from_wrapper(new) + # threaded = restore_unevaluated_from_wrapper(new) + if threaded.sameQ(new): + if contains_unevaluated_wrapper: + new = restore_unevaluated_from_wrapper(new) new._timestamp_cache(evaluation) return new, False else: + if contains_unevaluated_wrapper: + threaded = restore_unevaluated_from_wrapper(threaded) return threaded, True # Step 6: @@ -1283,6 +1318,10 @@ def flatten_callback(new_elements, old): # In Mathics3, the result in "fish", but WL gives "5". # This shows that WMA evaluates certain symbols differently. + if contains_unevaluated_wrapper: + wrapped_new = new + new = strip_unevaluated_wrapper(new) + def rules(): rules_names = set() if not A_HOLD_ALL_COMPLETE & attributes: @@ -1321,21 +1360,9 @@ def rules(): else: return result, True - # Step 7: If we are here, is because we didn't find any rule that - # matches the expression. - - dirty_elements = None - - # Expression did not change, re-apply Unevaluated - for index, element in enumerate(new._elements): - if element.unevaluated: - if dirty_elements is None: - dirty_elements = list(new._elements) - dirty_elements[index] = Expression(SymbolUnevaluated, element) - - if dirty_elements: - new = Expression(head) - new.elements = dirty_elements + # Step 7: If we are here, is because we didn't find any rule that matches with the expression. + if contains_unevaluated_wrapper: + new = restore_unevaluated_from_wrapper(wrapped_new) # Step 8: Update the cache. Return the new compound Expression and # indicate that no further evaluation is needed. diff --git a/test/test_evaluation.py b/test/test_evaluation.py index 14a6c4e10..55374edf7 100644 --- a/test/test_evaluation.py +++ b/test/test_evaluation.py @@ -62,55 +62,66 @@ def test_evaluation(str_expr: str, str_expected: str, message=""): assert result == expected -# We skip this test because this behaviour is currently -# broken @pytest.mark.parametrize( "str_setup,str_expr,str_expected,message", [ ( - "F[x___Real]:=List[x]^2; a=.4;", + "F[x_, y_,z_]:={Length[x], Length[y], Length[z]}; a:={1,1}", + "F[Unevaluated[a],a,Unevaluated[a]]", + "{2,2,2}", + "evaluated as {Length[a], Length[{1,1}], Length[a]}", + ), + ( + "F[x_, y_,z_]:={HoldForm[x], HoldForm[y], HoldForm[z]}; a:={1,1}", + "F[Unevaluated[a],a,Unevaluated[a]]", + "{HoldForm[a], {1, 1}, HoldForm[a]}", + "evaluated as {HoldForm[a], HoldForm[{1,1}], HoldForm[a]}", + ), + ( + "ClearAll[a,F]; F[x_Symbol, y_Real, z_Symbol] := {x^2,y^2,z^2};a=4.;", "F[Unevaluated[a], a, Unevaluated[a]]", - "F[Unevaluated[a], 0.4, Unevaluated[a]]", - None, + "{16., 16.,16.}", + ( + "Here the definition matchec because the first and last parameters are" + "keep unevaluated." + ), ), ( - "F[x___Real]:=List[x]^2; a=.4;", - "F[Unevaluated[b], b, Unevaluated[b]]", - "F[Unevaluated[b], b, Unevaluated[b]]", - "the second argument shouldn't be ``Unevaluated[b]``", + "ClearAll[a, b, F]; Attributes[F]=Flat;a=4.;", + "F[Unevaluated[a], a, F[b,1], Unevaluated[F[b,a]]]", + "F[Unevaluated[a], 4., b, 1, Unevaluated[b], Unevaluated[a]]", + "If F does not have a pattern that matches, keeps the unevaluated elements", ), ( - "G[x___Symbol]:=List[x]^2; a=.4;", - "G[Unevaluated[a], a, Unevaluated[a]]", - "F[Unevaluated[a], 0.4, Unevaluated[a]]", - None, + "ClearAll[a, b, F]; Attributes[F]={Orderless, Flat};a=4.;", + "F[Unevaluated[a], a, F[b,1], Unevaluated[F[b,a]]]", + "F[1, 4., Unevaluated[a], Unevaluated[a], b, Unevaluated[b]]", + "the same, with orderless. Unevaluated[expr] comes right after than expr.", ), ( - "G[x___Symbol]:=List[x]^2; a=.4;", - "G[Unevaluated[b], b, Unevaluated[b]]", - "F[Unevaluated[b], b, Unevaluated[b]]", - "the second argument shouldn't be ``Unevaluated[b]``", + "ClearAll[a, b, F,G]; Attributes[F]=Flat;a=4.;G[x_,y_]:=0", + "F[Unevaluated[a], a, G[b,1], Unevaluated[G[b,a]]]", + "F[Unevaluated[a], 4., 0, Unevaluated[G[b,a]]]", + "G is evaluated", ), + # ( + # "ClearAll[a, b, F,G]; Attributes[F]=Flat;a=4.;F[x_,y__]:={x,y}", + # "F[Unevaluated[a], a, G[b,1], Unevaluated[G[b,a]]]", + # "{F[4.], 4., G[b, 1], G[b, 4.]}", + # "Since F is successfully evaluated, Unevaluated is removed.", + # ), ( - "a =.; F[a, x_Real, a] := List[x]^2;a=4.;", - "F[Unevaluated[a], a, Unevaluated[a]]", - "{16.}", - "Here, the second ``a`` is kept unevaluated because of a bug.", + "ClearAll[a, b, F,G]; Attributes[F]=HoldFirst;a=4.;F[x_,y__]:={Hold[x],Hold[y]}", + "F[Unevaluated[a], a, G[b,1], Unevaluated[F[b,a]]]", + "{Hold[a], Hold[4., G[b, 1], F[b, a]]}", + "Since F is successfully evaluated, Unevaluated is removed.", ), ], ) -@pytest.mark.skip( - reason="the right behaviour was broken since we start to use Symbol as singleton, to speedup comparisons." -) def test_unevaluate(str_setup, str_expr, str_expected, message): if str_setup: evaluate(str_setup) - result = evaluate(str_expr) - expected = evaluate(str_expected) - if message: - assert result == expected, message - else: - assert result == expected + check_evaluation(str_expr, str_expected, message) @pytest.mark.parametrize(