diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index ed01daaa61..20bc844f31 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -388,6 +388,9 @@ def __init__( self.nodes_with_inner_function = [] self.output_keys = output_keys + if self.output_keys is not None: + warnings.warn("output_keys is deprecated.", FutureWarning) + assert len(self.output_storage) == len(self.maker.fgraph.outputs) # See if we have any mutable / borrow inputs @@ -810,20 +813,15 @@ def __call__(self, *args, **kwargs): if ``output_subset`` is not passed. """ - def restore_defaults(): - for i, (required, refeed, value) in enumerate(self.defaults): - if refeed: - if isinstance(value, Container): - value = value.storage[0] - self[i] = value - profile = self.profile if profile: t0 = time.perf_counter() output_subset = kwargs.pop("output_subset", None) - if output_subset is not None and self.output_keys is not None: - output_subset = [self.output_keys.index(key) for key in output_subset] + if output_subset is not None: + warnings.warn("output_subset is deprecated.", FutureWarning) + if self.output_keys is not None: + output_subset = [self.output_keys.index(key) for key in output_subset] # Reinitialize each container's 'provided' counter if self.trust_input: @@ -1565,6 +1563,8 @@ def __init__( ) for i in self.inputs ] + if any(self.refeed): + warnings.warn("Inputs with default values are deprecated.", FutureWarning) def create(self, input_storage=None, storage_map=None): """ diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index 2508683998..20bf29cea6 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -35,6 +35,9 @@ ) +pytestmark = pytest.mark.filterwarnings("error") + + def PatternOptimizer(p1, p2, ign=True): return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) @@ -195,7 +198,10 @@ def test_naming_rule3(self): x, s = scalars("xs") # x's name is not ignored (as in test_naming_rule2) because a has a default value. - f = function([x, In(a, value=1.0), s], a / s + x) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function([x, In(a, value=1.0), s], a / s + x) assert f(9, 2, 4) == 9.5 # can specify all args in order assert f(9, 2, s=4) == 9.5 # can give s as kwarg assert f(9, s=4) == 9.25 # can give s as kwarg, get default a @@ -214,7 +220,10 @@ def test_naming_rule4(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function([x, In(a, value=1.0, name="a"), s], a / s + x) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function([x, In(a, value=1.0, name="a"), s], a / s + x) assert f(9, 2, 4) == 9.5 # can specify all args in order assert f(9, 2, s=4) == 9.5 # can give s as kwarg @@ -248,11 +257,14 @@ def test_state_access(self, mode): a = scalar() x, s = scalars("xs") - f = function( - [x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)], - s + a * x, - mode=mode, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)], + s + a * x, + mode=mode, + ) assert f[a] == 1.0 assert f[s] == 0.0 @@ -303,16 +315,19 @@ def test_copy(self): a = scalar() x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) - g = copy.copy(f) + g = copy.copy(f) assert f.unpack_single == g.unpack_single assert f.trust_input == g.trust_input @@ -504,22 +519,25 @@ def test_shared_state0(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) - g = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=f.container[s], update=s - a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) + g = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=f.container[s], update=s - a * x, mutable=True), + ], + s + a * x, + ) f(1, 2) assert f[s] == 2 @@ -532,17 +550,20 @@ def test_shared_state1(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) - g = function( - [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) + g = function( + [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x + ) f(1, 2) assert f[s] == 2 @@ -556,17 +577,20 @@ def test_shared_state2(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=False), - ], - s + a * x, - ) - g = function( - [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=False), + ], + s + a * x, + ) + g = function( + [x, In(a, value=1.0, name="a"), In(s, value=f.container[s])], s + a * x + ) f(1, 2) assert f[s] == 2 @@ -718,7 +742,10 @@ def test_default_values(self): a, b = dscalars("a", "b") c = a + b - funct = function([In(a, name="first"), In(b, value=1, name="second")], c) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + funct = function([In(a, name="first"), In(b, value=1, name="second")], c) x = funct(first=1) try: funct(second=2) @@ -771,7 +798,8 @@ def test_output_dictionary(self): # Tests that function works when outputs is a dictionary x = scalar() - f = function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4}) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + f = function([x], outputs={"a": x, "c": x * 2, "b": x * 3, "1": x * 4}) outputs = f(10.0) @@ -786,7 +814,8 @@ def test_input_named_variables(self): x = scalar("x") y = scalar("y") - f = function([x, y], outputs={"a": x + y, "b": x * y}) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + f = function([x, y], outputs={"a": x + y, "b": x * y}) assert f(2, 4) == {"a": 6, "b": 8} assert f(2, y=4) == f(2, 4) @@ -801,9 +830,10 @@ def test_output_order_sorted(self): e1 = scalar("1") e2 = scalar("2") - f = function( - [x, y, z, e1, e2], outputs={"x": x, "y": y, "z": z, "1": e1, "2": e2} - ) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + f = function( + [x, y, z, e1, e2], outputs={"x": x, "y": y, "z": z, "1": e1, "2": e2} + ) assert "1" in str(f.outputs[0]) assert "2" in str(f.outputs[1]) @@ -821,7 +851,8 @@ def test_composing_function(self): a = x + y b = x * y - f = function([x, y], outputs={"a": a, "b": b}) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + f = function([x, y], outputs={"a": a, "b": b}) a = scalar("a") b = scalar("b") @@ -876,14 +907,17 @@ def test_deepcopy(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) try: g = copy.deepcopy(f) except NotImplementedError as e: @@ -932,14 +966,17 @@ def test_deepcopy_trust_input(self): a = dscalar() # the a is for 'anonymous' (un-named). x, s = dscalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) f.trust_input = True try: g = copy.deepcopy(f) @@ -958,11 +995,13 @@ def test_deepcopy_trust_input(self): def test_output_keys(self): x = vector() - f = function([x], {"vec": x**2}) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + f = function([x], {"vec": x**2}) o = f([2, 3, 4]) assert isinstance(o, dict) assert np.allclose(o["vec"], [4, 9, 16]) - g = copy.deepcopy(f) + with pytest.warns(FutureWarning, match="output_keys is deprecated."): + g = copy.deepcopy(f) o = g([2, 3, 4]) assert isinstance(o, dict) assert np.allclose(o["vec"], [4, 9, 16]) @@ -971,7 +1010,10 @@ def test_deepcopy_shared_container(self): # Ensure that shared containers remain shared after a deep copy. a, x = scalars("ax") - h = function([In(a, value=0.0)], a) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + h = function([In(a, value=0.0)], a) f = function([x, In(a, value=h.container[a], implicit=True)], x + a) try: @@ -995,14 +1037,17 @@ def test_pickle(self): a = scalar() # the a is for 'anonymous' (un-named). x, s = scalars("xs") - f = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) try: # Note that here we also test protocol 0 on purpose, since it @@ -1096,25 +1141,31 @@ def test_multiple_functions(self): # some derived thing, whose inputs aren't all in the list list_of_things.append(a * x + s) - f1 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f1 = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) list_of_things.append(f1) # now put in a function sharing container with the previous one - f2 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=f1.container[s], update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f2 = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=f1.container[s], update=s + a * x, mutable=True), + ], + s + a * x, + ) list_of_things.append(f2) assert isinstance(f2.container[s].storage, list) @@ -1122,7 +1173,10 @@ def test_multiple_functions(self): # now put in a function with non-scalar v_value = np.asarray([2, 3, 4.0], dtype=config.floatX) - f3 = function([x, In(v, value=v_value)], x + v) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + f3 = function([x, In(v, value=v_value)], x + v) list_of_things.append(f3) # try to pickle the entire things @@ -1254,23 +1308,29 @@ def __init__(self): self.e = a * x + s - self.f1 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=0.0, update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + self.f1 = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=0.0, update=s + a * x, mutable=True), + ], + s + a * x, + ) - self.f2 = function( - [ - x, - In(a, value=1.0, name="a"), - In(s, value=self.f1.container[s], update=s + a * x, mutable=True), - ], - s + a * x, - ) + with pytest.warns( + FutureWarning, match="Inputs with default values are deprecated." + ): + self.f2 = function( + [ + x, + In(a, value=1.0, name="a"), + In(s, value=self.f1.container[s], update=s + a * x, mutable=True), + ], + s + a * x, + ) def test_empty_givens_updates():