From 5544691f682fd50387db1444d31614e0cc71dfe1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 17 Aug 2023 12:49:08 -0500 Subject: [PATCH 1/4] [Unity] Implement relax.Function.bind_params Similar to `relax.Function.bind_symbolic_vars`, implemented in https://github.com/apache/tvm/pull/15509, this commit introduces `relax.Function.bind_params` to allow Relax parameters to be manipulated on a per-function basis. This utility function and the existing `BindParams` transform both use the same underlying implementation. --- include/tvm/relax/transform.h | 2 +- include/tvm/relax/utils.h | 22 +++ python/tvm/relax/expr.py | 50 ++++++ python/tvm/relax/transform/transform.py | 11 +- src/relax/transform/bind_params.cc | 116 +++++++++---- src/relax/utils.cc | 56 +++++++ tests/python/relax/test_bind_params.py | 156 ++++++++++++++++++ .../relax/test_transform_bind_params.py | 52 ++++++ 8 files changed, 426 insertions(+), 39 deletions(-) create mode 100644 tests/python/relax/test_bind_params.py diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 05b26f024212..6d3e92b82245 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -180,7 +180,7 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false); * * \return The Pass. */ -TVM_DLL Pass BindParams(String func_name, Map params); +TVM_DLL Pass BindParams(String func_name, Map params); /*! * \brief Bind symbolic vars to constant shape values. diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 1a6d5d4a5269..0e0249b863a6 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -24,7 +24,9 @@ #ifndef TVM_RELAX_UTILS_H_ #define TVM_RELAX_UTILS_H_ +#include #include +#include #include namespace tvm { @@ -48,6 +50,26 @@ namespace relax { TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds, const tvm::Map& symbolic_var_map = {}); +/*! + * \brief Infer a binding map for symbolic variables + * + * If a set of relax variables are replaced within an expression, this + * may result in removal of the definition site of a symbolic + * variable. This utility function determines the symbolic variable + * replacements that can be inferred based on the replaced relax + * variables, and can be used alongside the `Bind` utility function to + * replace both the relax variables and the implied symbolic + * variables. + * + * \param binds A map of relax variables to relax expressions + * + * \param analyzer The analyzer to use for simplifications + * + * \return A map of TIR variables to TIR expressions + */ +TVM_DLL tvm::Map InferSymbolicVarMap( + const tvm::Map& binds, arith::Analyzer* analyzer); + /*! * \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean * dtype). diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 49b91ffb3da1..cd5dfa2863a7 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -657,6 +657,56 @@ def bind_symbolic_vars( return _ffi_api.FunctionBindSymbolicVars(self, binding_map) # type: ignore + def bind_params( + self, + binding_map: Mapping[ + Union[str, Var], + Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr], + ], + ) -> "Function": + """Return a new function with updated symbolic variable + + Parameters + ---------- + binding_map: Mapping[ + Union[str, Var], + Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr], + ] + + The mapping of values to be replaced. + + Keys may be either a `relax.Var` or a string name of the + Relax variable. If the variables are referred to by name, + the name must uniquely identify a parameter in the + function. + + Values must be a relax expression, or a value that is + convertible into a relax expression. The value must be + compatible with the variable being replaced. + + Returns + ------- + func: Function + + The updated function + """ + + def _normalize_value(value): + # Conversions that must occur prior to the FFI + # conversions. + if isinstance(value, int): + # Relax uses int64 for symbolic variables, but the FFI + # converts python integers into int32. + return tvm.tir.const(value, "int64") + elif isinstance(value, (_np.ndarray, tvm.nd.NDArray)): + return tvm.relax.const(value) + else: + return value + + binding_map = {key: _normalize_value(value) for key, value in binding_map.items()} + + return _ffi_api.FunctionBindParams(self, binding_map) # type: ignore + @tvm._ffi.register_object("relax.expr.ExternFunc") class ExternFunc(BaseFunc): diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 438a6d1213e8..407805050547 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -387,7 +387,7 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass: def BindParams( func_name: str, - params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]], + params: Dict[Union[str, Var], Union[tvm.runtime.NDArray, np.ndarray]], ) -> tvm.ir.transform.Pass: """Bind params of function of the module to constant tensors. @@ -397,8 +397,13 @@ def BindParams( func_name: str The function name to be bound - params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]] - The map from param name to constant tensors. + params : Dict[ + Union[str,relax.Var], + Union[tvm.runtime.NDArray, np.ndarray], + ] + + The map from parameter or parameter name name to constant + tensors. Returns ------- diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index c444a84f44e0..27931b601760 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -25,6 +25,7 @@ #include #include +#include #include namespace tvm { @@ -81,45 +82,88 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant, } } +std::tuple, Map> NormalizeBindings( + const Function& func, const Map& untyped_params) { + ICHECK(func.defined()); + ICHECK(untyped_params.defined()); + + // Map from string to the variable(s) with that name. + std::unordered_map> string_lookup; + std::unordered_set var_set; + for (const auto& param : func->params) { + string_lookup[param->name_hint()].push_back(param); + var_set.insert(param.get()); + } + + Map relax_var_remap; + + auto normalize_key = [&](ObjectRef obj) -> relax::Var { + if (auto opt_str = obj.as()) { + std::string str = opt_str.value(); + auto it = string_lookup.find(str); + CHECK(it != string_lookup.end()) + << "Function does not have parameter with name \"" << str << "\". " + << "Function parameters are named " + << func->params.Map([](const auto& param) { return param->name_hint(); }); + CHECK_EQ(it->second.size(), 1) + << "Function contains multiple parameters with name \"" << str << "\". " + << "The Relax variables " << it->second << " are all named \"" << str << "\""; + auto var = it->second[0]; + CHECK(!relax_var_remap.count(var)) + << "Remap of variable " << var << " was defined multiple times"; + + return var; + } else if (auto opt_var = obj.as()) { + auto var = opt_var.value(); + CHECK(!relax_var_remap.count(var)) + << "Remap of variable " << var << " was defined multiple times"; + CHECK(var_set.count(var.get())) + << "Function does not use Relax variable " << var << " as a parameter. " + << "Function parameters are " << func->params; + return var; + } else { + LOG(FATAL) + << "Expected bound parameter to be a relax::Var, " + << " or a string that uniquely identifies a relax::Var param within the function. " + << "However, received object " << obj << " of type " << obj->GetTypeKey(); + } + }; + auto normalize_value = [&](ObjectRef obj) -> relax::Expr { + if (auto opt = obj.as()) { + return opt.value(); + } else if (auto opt = obj.as()) { + return Constant(opt.value()); + } else { + LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey() + << " into relax expression"; + } + }; + + for (const auto& [key, value] : untyped_params) { + relax_var_remap.Set(normalize_key(key), normalize_value(value)); + } + + arith::Analyzer analyzer; + Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); + + // for (const auto& [bind_param, bind_expr] : relax_var_remap) { + // MatchSymbolicVar(bind_param, bind_expr, &symbolic_var_map, &analyzer); + // } + + return {relax_var_remap, symbolic_var_map}; +} + /*! * \brief Bind params to function by using name * \param func Relax function * \param params params dict * \return Function */ -inline Function BindParamsByName(Function func, const Map& params) { - std::unordered_map name_dict; - std::unordered_set repeat_var; - for (auto arg : func->params) { - const auto& name = arg->name_hint(); - if (name_dict.count(name)) { - repeat_var.insert(name_dict[name]); - } else { - name_dict[name] = arg; - } - } +Function FunctionBindParams(Function func, const Map& untyped_params) { + auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params); - arith::Analyzer analyzer; - Map bind_dict; - Map symbolic_var_map; - - for (auto& kv : params) { - if (name_dict.count(kv.first) == 0) { - continue; - } - const Var& arg = name_dict.at(kv.first); - if (repeat_var.count(arg)) { - LOG(FATAL) << "ValueError: Multiple args in the function have name " << kv.first; - } - Expr const_expr = Constant(kv.second); - bind_dict.Set(arg, const_expr); - MatchSymbolicVar(arg, const_expr, &symbolic_var_map, &analyzer); - } Expr bound_expr = Bind(func, bind_dict, symbolic_var_map); - Function ret = Downcast(bound_expr); - ICHECK(ret.defined()) << "The returning type is expected to be a Relax Function." - << "\n"; - return ret; + return Downcast(bound_expr); } /*! @@ -129,7 +173,7 @@ inline Function BindParamsByName(Function func, const Map param) { +IRModule BindParam(IRModule m, String func_name, Map bind_params) { IRModuleNode* new_module = m.CopyOnWrite(); Map functions = m->functions; for (const auto& func_pr : functions) { @@ -138,13 +182,13 @@ IRModule BindParam(IRModule m, String func_name, Map p // Use global_symbol if it's external linkage Optional gsymbol = relax_f->GetAttr(tvm::attr::kGlobalSymbol); if (gsymbol.defined() && gsymbol.value() == func_name) { - Function f_after_bind = BindParamsByName(GetRef(relax_f), param); + Function f_after_bind = FunctionBindParams(GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } } else { // Use global var's name_hint if it's internal linkage if (func_pr.first->name_hint == func_name) { - Function f_after_bind = BindParamsByName(GetRef(relax_f), param); + Function f_after_bind = FunctionBindParams(GetRef(relax_f), bind_params); new_module->Update(func_pr.first, f_after_bind); } } @@ -153,9 +197,11 @@ IRModule BindParam(IRModule m, String func_name, Map p return GetRef(new_module); } +TVM_REGISTER_GLOBAL("relax.FunctionBindParams").set_body_typed(FunctionBindParams); + namespace transform { -Pass BindParams(String func_name, Map params) { +Pass BindParams(String func_name, Map params) { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); }; return CreateModulePass(pass_func, 0, "BindParams", {}); diff --git a/src/relax/utils.cc b/src/relax/utils.cc index ccb72805e371..f8235def240b 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -144,6 +144,62 @@ Expr Bind(const Expr& expr, const tvm::Map& binds, return ExprBinder(binds, symbolic_var_map).VisitExpr(expr); } +tvm::Map InferSymbolicVarMap( + const tvm::Map& relax_var_remap, arith::Analyzer* analyzer) { + tvm::Map tir_var_remap; + + auto bind_from_prim_expr = [&tir_var_remap](const PrimExpr& var_shape, + const PrimExpr& expr_shape) { + if (auto var = var_shape.as()) { + tir_var_remap.Set(var.value(), expr_shape); + } + }; + + auto bind_from_shape = [&bind_from_prim_expr](const StructInfo& var, const StructInfo& expr) { + auto var_shape = var.as(); + if (!var_shape) return; + if (!var_shape->values.defined()) return; + + auto expr_shape = expr.as(); + CHECK(expr_shape) << "Cannot bind expression with struct type " << expr + << " to variable with struct type " << var; + if (!expr_shape->values.defined()) return; + + auto var_shape_arr = var_shape->values.value(); + auto expr_shape_arr = expr_shape->values.value(); + CHECK_EQ(var_shape_arr.size(), expr_shape_arr.size()) + << "Cannot bind shape " << expr_shape_arr << " of dimension " << expr_shape_arr.size() + << " to variable with shape " << var_shape_arr << " of dimension " << var_shape_arr.size(); + for (size_t i = 0; i < var_shape_arr.size(); i++) { + bind_from_prim_expr(var_shape_arr[i], expr_shape_arr[i]); + } + }; + + auto bind_from_tensor = [&bind_from_shape](const StructInfo& var, const StructInfo& expr) { + auto var_tensor = var.as(); + if (!var_tensor) return; + if (!var_tensor->shape.defined()) return; + + auto expr_tensor = expr.as(); + CHECK(expr_tensor) << "Cannot bind expression with struct type " << expr + << " to variable with struct type " << var; + if (!expr_tensor->shape.defined()) return; + + bind_from_shape(GetStructInfo(var_tensor->shape.value()), + GetStructInfo(expr_tensor->shape.value())); + }; + + for (const auto& [relax_var, relax_expr] : relax_var_remap) { + auto var_sinfo = GetStructInfo(relax_var); + auto expr_sinfo = GetStructInfo(relax_expr); + + bind_from_tensor(var_sinfo, expr_sinfo); + bind_from_shape(var_sinfo, expr_sinfo); + } + + return tir_var_remap; +} + bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank, bool permit_unknown_dtype) { const TensorStructInfoNode* tt = sinfo.as(); diff --git a/tests/python/relax/test_bind_params.py b/tests/python/relax/test_bind_params.py new file mode 100644 index 000000000000..a92e4fe8e510 --- /dev/null +++ b/tests/python/relax/test_bind_params.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import tvm +import tvm.script +import tvm.testing +from tvm import relax, tir +from tvm.script import relax as R + +import numpy as np +import pytest + +param_specification = tvm.testing.parameter("by_string", "by_var") +param_shape = tvm.testing.parameter("static_shape", "dynamic_shape", "ndim", "arbitrary") +tensor_param_dtype = tvm.testing.parameter("float32", None) + + +def test_bind_tensor_param(param_specification, param_shape, tensor_param_dtype): + if param_shape == "static_shape": + shape = [16] + ndim = -1 + elif param_shape == "dynamic_shape": + shape = [tir.Var("N", "int64")] + ndim = -1 + elif param_shape == "ndim": + shape = None + ndim = 1 + elif param_shape == "arbitrary": + shape = None + ndim = -1 + else: + raise ValueError(f"Unknown param_shape: {param_shape}") + + @R.function + def before(A: R.Tensor(shape, ndim=ndim, dtype=tensor_param_dtype)): + R.func_attr({"global_symbol": "main"}) + B: R.Tensor(shape=shape, ndim=ndim, dtype=tensor_param_dtype) = A + out = R.add(B, B) + return out + + np_data = np.arange(16).astype("float32") + inlined_relax_const = relax.const(np_data) + + @R.function + def expected() -> R.Tensor([16], "float32"): + R.func_attr({"global_symbol": "main"}) + B = inlined_relax_const + out = R.add(B, B) + return out + + if param_specification == "by_string": + var = "A" + elif param_specification == "by_var": + var = before.params[0] + else: + raise ValueError("Unknown param_specification: {param_specification}") + + after = before.bind_params({var: np.arange(16).astype("float32")}) + + tvm.ir.assert_structural_equal(expected, after) + + +def test_bind_shape_param(param_shape): + if param_shape == "static_shape": + shape = [16] + ndim = -1 + elif param_shape == "dynamic_shape": + shape = [tir.Var("N", "int64")] + ndim = -1 + elif param_shape == "ndim": + shape = None + ndim = 1 + elif param_shape == "arbitrary": + shape = None + ndim = -1 + else: + raise ValueError(f"Unknown param_shape: {param_shape}") + + @R.function + def before(A: R.Shape(shape, ndim=ndim)): + R.func_attr({"global_symbol": "main"}) + B: R.Shape(shape, ndim=ndim) = A + return B + + @R.function + def expected() -> R.Shape([16]): + R.func_attr({"global_symbol": "main"}) + B = R.ShapeExpr([16]) + return B + + after = before.bind_params({"A": relax.ShapeExpr([16])}) + + tvm.ir.assert_structural_equal(expected, after) + + +prim_value_dtype = tvm.testing.parameter("int64", "int32", "float32") + + +@pytest.mark.xfail(reason="Depends on relax.PrimValue holding a tir.PrimExpr, PR#15577") +def test_bind_prim_value(prim_value_dtype): + @R.function + def before(A: R.Prim(value="N", dtype=prim_value_dtype)): + R.func_attr({"global_symbol": "main"}) + B: R.Prim(value="N", dtype=prim_value_dtype) = A + return B + + @R.function + def expected() -> R.Prim(value=16, dtype=prim_value_dtype): + R.func_attr({"global_symbol": "main"}) + B = R.PrimValue(value=16, dtype=dtype) + return B + + after = before.bind_params({"A": relax.PrimValue(tir.const(16, prim_value_dtype))}) + + tvm.ir.assert_structural_equal(expected, after) + + +def test_error_on_unknown_var(): + @R.function + def before(A: R.Tensor([16], dtype="float32")): + R.func_attr({"global_symbol": "main"}) + return A + + unknown_var = relax.Var("unknown_var") + + with pytest.raises(tvm.TVMError): + before.bind_params({unknown_var: np.arange(16).astype("float32")}) + + +def test_error_on_unknown_var_name(): + @R.function + def before(A: R.Tensor([16], dtype="float32")): + R.func_attr({"global_symbol": "main"}) + return A + + with pytest.raises(tvm.TVMError): + before.bind_params({"unknown_var_name": np.arange(16).astype("float32")}) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py index 8e760b6fd70f..9e212693f969 100644 --- a/tests/python/relax/test_transform_bind_params.py +++ b/tests/python/relax/test_transform_bind_params.py @@ -123,5 +123,57 @@ def main( ) +param_specification = tvm.testing.parameter("by_string", "by_var") + + +def test_bind_params_by_var_obj(param_specification): + @tvm.script.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + return A + + np_data = np.arange(16).astype("float32") + inlined_relax_const = relax.const(np_data) + + @tvm.script.ir_module + class Expected: + @R.function + def main(): + return inlined_relax_const + + if param_specification == "by_string": + var = "A" + elif param_specification == "by_var": + var = Before["main"].params[0] + else: + raise ValueError("Unknown param_specification: {param_specification}") + + After = relax.transform.BindParams("main", {var: np_data})(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + +def test_bind_params_by_var_name(): + @tvm.script.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + return A + + np_data = np.arange(16).astype("float32") + inlined_relax_const = relax.const(np_data) + + @tvm.script.ir_module + class Expected: + @R.function + def main(): + return inlined_relax_const + + After = relax.transform.BindParams("main", {"A": np_data})(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main() From 5e166ebe9bd54d3a003adf60a789911a02918d4b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 1 Sep 2023 14:44:00 -0500 Subject: [PATCH 2/4] Update relay_translator unit tests to avoid duplicate binding --- tests/python/relax/test_relay_translator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/python/relax/test_relay_translator.py b/tests/python/relax/test_relay_translator.py index 54cd1b243dbf..c752fa5e1015 100644 --- a/tests/python/relax/test_relay_translator.py +++ b/tests/python/relax/test_relay_translator.py @@ -126,10 +126,14 @@ def test_verify_e2e_translation_gpu(layout, batch_size, image_shape): def verify_extracted_tasks(target_str, layout, batch_size, image_shape, module_equality): target = Target(target_str) relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape) + # Parameters can be bound either as part of the `from_relay` + # conversion, or as part of the `extract_tasks` method. However, + # they shouldn't be used in both locations, because + # `relax.BindParams` validates that there exists an unbound + # parameter of the specified name. relax_mod = relay_translator.from_relay( relay_mod["main"], target, - params, pass_config={ "relay.backend.use_meta_schedule": True, "relay.FuseOps.max_depth": 1, # Disable relay fusion From 67cff52341c2c81e41378b169f8a72a0505b83cb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 1 Sep 2023 21:00:04 -0500 Subject: [PATCH 3/4] Updated unit test that attempted to bind non-existent parameter --- tests/python/relax/test_transform_fold_constant.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index b8ad5c4487d3..c2a3bd50922b 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -378,8 +378,7 @@ def expected(data: R.Tensor((256,), "float32")) -> R.Tensor((16, 16), dtype="flo before = gen_mod(Module, "before", {"c0": c0_np, "c1": c1_np}) assert relax.analysis.well_formed(before) - c2_np = np.multiply(np.add(c0_np, c0_np), c1_np) - expected = gen_mod(Module, "expected", {"c2": c2_np}) + expected = gen_mod(Module, "expected", {}) after = relax.transform.FoldConstant()(before) tvm.ir.assert_structural_equal(after, expected) From 0869a9b33b896e085f78ee21321af580f2417096 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 2 Sep 2023 16:08:38 -0500 Subject: [PATCH 4/4] ci bump