From d9ee6377cdd8395b27385d2fc2745b741fad6183 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 29 Sep 2024 06:59:33 +0900 Subject: [PATCH] [Relax][PyTorch] Support neural network ops for ExportedProgram importer (#17426) * support batchnorm2d and getitem * support addmm * support avg_pool2d * support baddbmm * support bmm * support conv_transpose1d * support conv_transpose2d * support conv1d * support conv3d * support einsum * support embedding * support group_norm * support layer_norm * support scaled_dot_product_attention * support unbind * support interpolate * fix lint error --- .../torch/base_fx_graph_translator.py | 464 +++++++ .../torch/exported_program_translator.py | 111 ++ .../tvm/relax/frontend/torch/fx_translator.py | 482 +------ .../test_frontend_from_exported_program.py | 1150 ++++++++++++++++- 4 files changed, 1723 insertions(+), 484 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index a41b9b6d4f9a..52784dc8c3cd 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -227,6 +227,228 @@ def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) + def _addmm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + y = self.env[node.args[1]] + z = self.env[node.args[2]] + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) + + res = None + if alpha != 0: + res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) + if alpha != 1: + dtype = res.struct_info.dtype + res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) + if beta != 0: + dtype = x.struct_info.dtype + if beta != 1: + bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) + else: + bias = x + res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) + return res + + def _avg_pool2d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[int] = 0, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + stride = kernel_size if stride is None or stride == [] else stride + return self.block_builder.emit( + relax.op.nn.avg_pool2d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + ceil_mode=ceil_mode, + layout="NCHW", + ) + ) + + def _avg_pool2d(self, node: fx.Node) -> relax.Var: + args, kwargs = node.normalized_arguments(node) + x = self.env[args[0]] + kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"] + stride = args[2] if len(args) > 2 else kwargs.get("stride", None) + padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) + ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) + return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + + def _baddbmm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + batch1 = self.env[node.args[1]] + batch2 = self.env[node.args[2]] + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) + + res = None + if alpha != 0: + res = self.block_builder.emit(relax.op.matmul(batch1, batch2)) + if alpha != 1: + dtype = res.struct_info.dtype + res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) + if beta != 0: + dtype = x.struct_info.dtype + if beta != 1: + bias = 
self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) + else: + bias = x + res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) + return res + + def _conv_transpose1d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv1d_transpose = self.block_builder.emit( + relax.op.nn.conv1d_transpose( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv1d_transpose + + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) + + def _conv_transpose1d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv_transpose1d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv_transpose2d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv2d_transpose = self.block_builder.emit( + relax.op.nn.conv2d_transpose( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv2d_transpose + + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) + + def _conv_transpose2d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv_transpose2d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv1d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv1d = self.block_builder.emit( + relax.op.nn.conv1d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv1d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + return self.block_builder.emit(relax.op.add(conv1d, bias)) + + def _conv1d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_impl( + x, + 
weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + def _conv2d_impl( self, x: relax.Expr, @@ -276,6 +498,134 @@ def _conv2d(self, node: fx.Node) -> relax.Var: groups=groups, ) + def _conv3d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ): + conv3d = self.block_builder.emit( + relax.op.nn.conv3d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCDHW", + kernel_layout="OIDHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv3d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv3d, bias)) + + def _conv3d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv3d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _einsum(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.einsum(operands, args[0])) + + def _embedding_impl( + self, + x, + weight, + ) -> relax.Var: + x = self.block_builder.emit(relax.op.astype(x, "int32")) + + ndim = x.struct_info.ndim + if ndim == 1: + return self.block_builder.emit(relax.op.take(weight, x, axis=0)) + else: + x_shape = x.struct_info.shape.values + emb_size = weight.struct_info.shape.values[-1] + x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) + embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) + return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) + + def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var: + from torch.fx.immutable_collections import immutable_list + import numpy as np # type: ignore + + if isinstance(normalized_shape, (immutable_list, tuple)): + normalized_shape = tuple(normalized_shape) + else: + try: + normalized_shape = self.env[normalized_shape] + except TypeError: + normalized_shape = tuple(normalized_shape) + + dim_num = len(normalized_shape) + axes = list(range(-dim_num, 0)) + + if gamma is None: + shape_tuple = [int(s) for s in normalized_shape] + gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) + if beta is None: + shape_tuple = [int(s) for s in normalized_shape] + beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) + + return self.block_builder.emit( + relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=eps, + ) + ) + + def _layer_norm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + normalized_shape = node.args[1] + gamma = self.env[node.args[2]] if len(node.args) > 2 else None + beta = self.env[node.args[3]] if len(node.args) > 3 else None + eps = node.args[4] if len(node.args) > 4 else 1e-05 + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + + def _layer_norm_module(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = 
self.named_modules[node.target] + normalized_shape = module.normalized_shape + if module.elementwise_affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) + eps = module.eps + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + def _linear(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] @@ -316,6 +666,39 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: + transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) + query = transpose_S_H(self.env[node.args[0]]) + key = transpose_S_H(self.env[node.args[1]]) + value = transpose_S_H(self.env[node.args[2]]) + attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) + dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) + assert dropout_p == 0.0, "Dropout is not supported" + is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False) + causal_mask = "TopLeft" if is_causal else None + + if attn_mask is not None: + attn_mask = self.env[attn_mask] + msg = "Only a float mask is supported for the attn_mask input." + assert "float" in attn_mask.struct_info.dtype, msg + + return self.block_builder.emit( + transpose_S_H( + relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) + ) + ) + + def _unbind(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + assert isinstance(dim, int), "Expected 2nd argument of unbind as int" + selections = self.shape_of(x)[dim].value + n_section = list(range(1, selections + 1)) + ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) + for i in range(selections): + ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) + return self.block_builder.emit(relax.Tuple(ret)) + ########## Statistical ########## def _mean(self, node: fx.Node) -> relax.Var: @@ -357,6 +740,87 @@ def _reshape(self, node: fx.Node) -> relax.Var: ########## Others ########## + def _getitem(self, node: fx.Node) -> relax.Var: + import torch + + x = self.env[node.args[0]] + if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): + return x[node.args[1]] + elif isinstance(x, relax.Var): + if isinstance(x.struct_info, relax.TupleStructInfo): + return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) + + assert isinstance(x.struct_info, relax.TensorStructInfo) + take_indices = [] + take_axes = [] + stride_begin = [] + stride_end = [] + stride = [] + stride_axes = [] + expand_dim = [] + i = 0 + shape = self.shape_of(x) + non_ellipsis_cnt = 0 + for index in node.args[1]: + if isinstance(index, (int, slice, torch.fx.Node)): + non_ellipsis_cnt += 1 + for index in node.args[1]: + if isinstance(index, int): + stride_begin.append(index) + stride_end.append(index + 1) + stride.append(1) + stride_axes.append(i) + i = i + 1 + elif isinstance(index, slice): + stride_begin.append(0 if index.start is None else index.start) + stride_end.append(shape[i] if index.stop is None else index.stop) + stride.append(1 if index.step is None else index.step) + stride_axes.append(i) + i = i + 1 + elif index is 
None: + expand_dim.append(len(stride_axes) + len(expand_dim)) + elif index is Ellipsis: + for _ in range(len(shape) - non_ellipsis_cnt): + stride_begin.append(0) + stride_end.append(shape[i]) + stride.append(1) + stride_axes.append(i) + i += 1 + elif isinstance(index, torch.fx.Node): + node_index = self.env[index] + if not isinstance(node_index, relax.Expr): + raise ValueError( + "Unsupported index type for relax.op.take: " + str(type(node_index)) + ) + take_indices.append(node_index) + take_axes.append(i) + i = i + 1 + else: + raise ValueError("Unsupported index type: " + str(type(index))) + while i < len(shape): + stride_begin.append(0) + stride_end.append(shape[i]) + stride.append(1) + stride_axes.append(i) + i += 1 + taken = x + if len(take_indices) > 1: + raise ValueError("Multiple tensors as index not yet supported") + for each_index, each_axis in zip(take_indices, take_axes): + taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis)) + sliced = self.block_builder.emit( + relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride) + ) + sliced_shape = list(self.shape_of(sliced)) + for i in expand_dim: + sliced_shape.insert(i, 1) + return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) + elif isinstance(x, relax.Constant): + dtype = x.struct_info.dtype + return relax.const(x.data.numpy()[node.args[1]], dtype) + else: + assert False + @abc.abstractmethod def create_convert_map( self, diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 11594690cdc2..64583d750974 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -74,6 +74,94 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr: max_val = node.args[2] if len(args) > 2 else node.kwargs("max_val", 1.0) return self.block_builder.emit(relax.op.clip(x, min_val, max_val)) + ########## Neural Network ########## + + def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: + import numpy as np + + x = self.env[node.args[0]] + channel = int(self.shape_of(x)[1]) + dtype = x.struct_info.dtype + weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype)) + bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) + running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) + running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) + momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) + eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) + + return self.block_builder.emit( + relax.op.nn.batch_norm( + x, + weight, + bias, + running_mean, + running_var, + axis=1, + epsilon=eps, + momentum=momentum, + ) + ) + + def _group_norm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + num_groups = node.args[1] + gamma = self.env[node.args[2]] if len(node.args) > 2 else None + beta = self.env[node.args[3]] if len(node.args) > 3 else None + eps = node.args[4] if len(node.args) > 4 else 1e-05 + + dim = len(self.shape_of(x)) + return self.block_builder.emit( + relax.op.nn.group_norm( + x, + gamma, + beta, + num_groups=num_groups, + channel_axis=1, + axes=list(range(2, dim)), + epsilon=eps, + ) + ) + + def _upsample_impl( + self, x: relax.Expr, size, align_corners: bool, scale_factor, method: str + ) -> relax.Var: + coord_trans = "align_corners" if 
align_corners else "half_pixel" + + if size is None: + shape = self.shape_of(x) + assert isinstance(shape, relax.ShapeExpr) + if isinstance(scale_factor, (tuple, list)): + assert len(scale_factor) == len(shape) - 2 + size = tuple( + int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape)) + ) + else: + size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + + return self.block_builder.emit( + relax.op.image.resize2d( + x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + ) + ) + + def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True) + ) + scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None) + return self._upsample_impl(x, size, align_corners, scale_factor, "linear") + + def _upsample_nearest2d(self, node: fx.node) -> relax.Var: + x = self.env[node.args[0]] + size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None) + align_corners = ( + node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True) + ) + scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None) + return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor") + def create_convert_map( self, ) -> Dict[str, Callable[[fx.Node], relax.Var]]: @@ -129,10 +217,31 @@ def create_convert_map( "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow), "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub), # neural network + "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, + "addmm.default": self._addmm, + "avg_pool2d.default": self._avg_pool2d, + "baddbmm.default": self._baddbmm, + "bmm.default": self._binary_op( + partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul + ), + "conv_transpose1d.default": self._conv_transpose1d, + "conv_transpose2d.input": self._conv_transpose2d, + "conv1d.default": self._conv1d, "conv2d.default": self._conv2d, + "conv3d.default": self._conv3d, + "einsum.default": self._einsum, + "embedding.default": lambda node: self._embedding_impl( + self.env[node.args[1]], self.env[node.args[0]] + ), + "group_norm.default": self._group_norm, + "layer_norm.default": self._layer_norm, "linear.default": self._linear, "max_pool2d.default": self._max_pool2d, + "scaled_dot_product_attention.default": self._scaled_dot_product_attention, + "unbind.int": self._unbind, + "upsample_bilinear2d.vec": self._upsample_bilinear2d, + "upsample_nearest2d.vec": self._upsample_nearest2d, # statistical "mean.dim": self._mean, "sum.dim_IntList": self._sum, @@ -141,6 +250,8 @@ def create_convert_map( "argmin.default": self._argmax_argmin(relax.op.argmin), # tensor manipulation "view.default": self._reshape, + # other + "getitem": self._getitem, } def from_exported_program( diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index dc6ebc2eb34f..c60c7c3953b4 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -18,7 +18,7 @@ # pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck # pylint: disable=import-outside-toplevel """PyTorch FX frontend 
of Relax.""" -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union from functools import partial, reduce import tvm @@ -107,57 +107,6 @@ def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) - def _addmm(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - y = self.env[node.args[1]] - z = self.env[node.args[2]] - alpha = node.kwargs.get("alpha", 1) - beta = node.kwargs.get("beta", 1) - - res = None - if alpha != 0: - res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) - if alpha != 1: - dtype = res.struct_info.dtype - res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) - if beta != 0: - dtype = x.struct_info.dtype - if beta != 1: - bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) - else: - bias = x - res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) - return res - - def _avg_pool2d_impl( - self, - x: relax.Expr, - kernel_size: Union[int, Tuple[int, int]] = (1, 1), - stride: Optional[Union[int, Tuple[int, int]]] = None, - padding: Optional[int] = 0, - ceil_mode: Optional[bool] = False, - ) -> relax.Var: - stride = kernel_size if stride is None or stride == [] else stride - return self.block_builder.emit( - relax.op.nn.avg_pool2d( - x, - pool_size=kernel_size, - strides=stride, - padding=padding, - ceil_mode=ceil_mode, - layout="NCHW", - ) - ) - - def _avg_pool2d(self, node: fx.Node) -> relax.Var: - args, kwargs = node.normalized_arguments(node) - x = self.env[args[0]] - kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"] - stride = args[2] if len(args) > 2 else kwargs.get("stride", None) - padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) - ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) - return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) - def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -167,28 +116,6 @@ def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: ceil_mode = module.ceil_mode return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) - def _baddbmm(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - a = self.env[node.args[1]] - b = self.env[node.args[2]] - alpha = node.kwargs.get("alpha", 1) - beta = node.kwargs.get("beta", 1) - - res = None - if alpha != 0: - res = self.block_builder.emit(relax.op.matmul(a, b)) - if alpha != 1: - dtype = res.struct_info.dtype - res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) - if beta != 0: - dtype = x.struct_info.dtype - if beta != 1: - bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) - else: - bias = x - res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) - return res - def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -212,63 +139,13 @@ def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) - def _conv1d_transpose_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - 
conv1d_transpose = self.block_builder.emit( - relax.op.nn.conv1d_transpose( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCW", - kernel_layout="OIW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv1d_transpose - - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) - - def _conv1d_transpose(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv1d_transpose_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - - def _conv1d_transpose_module(self, node: fx.Node) -> relax.Var: + def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] bias = self.params.get(module.bias, None) - return self._conv1d_transpose_impl( + return self._conv_transpose1d_impl( x, weight, bias=bias, @@ -278,63 +155,13 @@ def _conv1d_transpose_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_transpose_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - conv2d_transpose = self.block_builder.emit( - relax.op.nn.conv2d_transpose( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCHW", - kernel_layout="OIHW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv2d_transpose - - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) - - def _conv2d_transpose(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv2d_transpose_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - - def _conv2d_transpose_module(self, node: fx.Node) -> relax.Var: + def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] bias = self.params.get(module.bias, None) - return self._conv2d_transpose_impl( + return self._conv_transpose2d_impl( x, weight, bias=bias, @@ -344,55 +171,6 @@ def _conv2d_transpose_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv1d_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - conv1d = self.block_builder.emit( - relax.op.nn.conv1d( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCW", - kernel_layout="OIW", - out_dtype="float32", - ) - ) - - if 
bias is None: - return conv1d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d, bias)) - - def _conv1d(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv1d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv1d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -425,55 +203,6 @@ def _conv2d_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv3d_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ): - conv3d = self.block_builder.emit( - relax.op.nn.conv3d( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCDHW", - kernel_layout="OIDHW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv3d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv3d, bias)) - - def _conv3d(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv3d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv3d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -524,30 +253,6 @@ def _cross_entropy_module(self, node: fx.Node) -> relax.Expr: ) ) - def _einsum(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] - return self.block_builder.emit(relax.op.einsum(operands, args[0])) - - def _embedding_impl( - self, - x, - weight, - ) -> relax.Var: - x = self.block_builder.emit(relax.op.astype(x, "int32")) - - ndim = x.struct_info.ndim - if ndim == 1: - return self.block_builder.emit(relax.op.take(weight, x, axis=0)) - else: - x_shape = x.struct_info.shape.values - emb_size = weight.struct_info.shape.values[-1] - x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) - embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) - return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) - def _embedding_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -655,61 +360,6 @@ def _interpolate(self, node: fx.Node) -> relax.Var: ) ) - def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var: - from torch.fx.immutable_collections import immutable_list - import numpy as np # type: ignore - - if isinstance(normalized_shape, (immutable_list, tuple)): - normalized_shape = tuple(normalized_shape) - else: - try: - normalized_shape = self.env[normalized_shape] 
- except TypeError: - normalized_shape = tuple(normalized_shape) - - dim_num = len(normalized_shape) - axes = list(range(-dim_num, 0)) - - if gamma is None: - shape_tuple = [int(s) for s in normalized_shape] - gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) - if beta is None: - shape_tuple = [int(s) for s in normalized_shape] - beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) - - return self.block_builder.emit( - relax.op.nn.layer_norm( - x, - gamma, - beta, - axes=axes, - epsilon=eps, - ) - ) - - def _layer_norm(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - normalized_shape = node.args[1] - gamma = self.env[node.args[2]] if len(node.args) > 2 else None - beta = self.env[node.args[3]] if len(node.args) > 3 else None - eps = node.args[4] if len(node.args) > 4 else 1e-05 - return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) - - def _layer_norm_module(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - x = self.env[node.args[0]] - module = self.named_modules[node.target] - normalized_shape = module.normalized_shape - if module.elementwise_affine: - gamma = self.params[module.weight] - beta = self.params[module.bias] - else: - gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) - beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) - eps = module.eps - return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) - def _linear_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -728,39 +378,6 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) - def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: - transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) - query = transpose_S_H(self.env[node.args[0]]) - key = transpose_S_H(self.env[node.args[1]]) - value = transpose_S_H(self.env[node.args[2]]) - attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) - dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) - assert dropout_p == 0.0, "Dropout is not supported" - is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False) - causal_mask = "TopLeft" if is_causal else None - - if attn_mask is not None: - attn_mask = self.env[attn_mask] - msg = "Only a float mask is supported for the attn_mask input." 
- assert "float" in attn_mask.struct_info.dtype, msg - - return self.block_builder.emit( - transpose_S_H( - relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) - ) - ) - - def _unbind(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) - assert isinstance(dim, int), "Expected 2nd argument of unbind as int" - selections = self.shape_of(x)[dim].value - n_section = list(range(1, selections + 1)) - ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) - for i in range(selections): - ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) - return self.block_builder.emit(relax.Tuple(ret)) - ########## Manipulation ########## def _cat(self, node: fx.Node) -> relax.Var: @@ -1054,87 +671,6 @@ def _getattr(self, node: fx.Node) -> relax.Var: return self.shape_of(self.env[node.args[0]]) return getattr(self.env[node.args[0]], node.args[1]) - def _getitem(self, node: fx.Node) -> relax.Var: - import torch - - x = self.env[node.args[0]] - if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): - return x[node.args[1]] - elif isinstance(x, relax.Var): - if isinstance(x.struct_info, relax.TupleStructInfo): - return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) - - assert isinstance(x.struct_info, relax.TensorStructInfo) - take_indices = [] - take_axes = [] - stride_begin = [] - stride_end = [] - stride = [] - stride_axes = [] - expand_dim = [] - i = 0 - shape = self.shape_of(x) - non_ellipsis_cnt = 0 - for index in node.args[1]: - if isinstance(index, (int, slice, torch.fx.Node)): - non_ellipsis_cnt += 1 - for index in node.args[1]: - if isinstance(index, int): - stride_begin.append(index) - stride_end.append(index + 1) - stride.append(1) - stride_axes.append(i) - i = i + 1 - elif isinstance(index, slice): - stride_begin.append(0 if index.start is None else index.start) - stride_end.append(shape[i] if index.stop is None else index.stop) - stride.append(1 if index.step is None else index.step) - stride_axes.append(i) - i = i + 1 - elif index is None: - expand_dim.append(len(stride_axes) + len(expand_dim)) - elif index is Ellipsis: - for _ in range(len(shape) - non_ellipsis_cnt): - stride_begin.append(0) - stride_end.append(shape[i]) - stride.append(1) - stride_axes.append(i) - i += 1 - elif isinstance(index, torch.fx.Node): - node_index = self.env[index] - if not isinstance(node_index, relax.Expr): - raise ValueError( - "Unsupported index type for relax.op.take: " + str(type(node_index)) - ) - take_indices.append(node_index) - take_axes.append(i) - i = i + 1 - else: - raise ValueError("Unsupported index type: " + str(type(index))) - while i < len(shape): - stride_begin.append(0) - stride_end.append(shape[i]) - stride.append(1) - stride_axes.append(i) - i += 1 - taken = x - if len(take_indices) > 1: - raise ValueError("Multiple tensors as index not yet supported") - for each_index, each_axis in zip(take_indices, take_axes): - taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis)) - sliced = self.block_builder.emit( - relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride) - ) - sliced_shape = list(self.shape_of(sliced)) - for i in expand_dim: - sliced_shape.insert(i, 1) - return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) - elif isinstance(x, relax.Constant): - dtype = x.struct_info.dtype - return relax.const(x.data.numpy()[node.args[1]], dtype) - else: - assert False - 
def _sym_size_int(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -1182,8 +718,8 @@ def create_convert_map( nn.Conv1d: self._conv1d_module, nn.Conv2d: self._conv2d_module, nn.Conv3d: self._conv3d_module, - nn.ConvTranspose1d: self._conv1d_transpose_module, - nn.ConvTranspose2d: self._conv2d_transpose_module, + nn.ConvTranspose1d: self._conv_transpose1d_module, + nn.ConvTranspose2d: self._conv_transpose2d_module, nn.CrossEntropyLoss: self._cross_entropy_module, nn.GroupNorm: self._group_norm_module, nn.LayerNorm: self._layer_norm_module, @@ -1248,8 +784,8 @@ def create_convert_map( "bmm": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), - "conv_transpose1d": self._conv1d_transpose, - "conv_transpose2d": self._conv2d_transpose, + "conv_transpose1d": self._conv_transpose1d, + "conv_transpose2d": self._conv_transpose2d, "conv1d": self._conv1d, "conv2d": self._conv2d, "conv3d": self._conv3d, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 25e6dbfae308..7c887d9b9610 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1156,6 +1156,59 @@ def main( verify_model(Sub2(), example_args2, {}, expected_sub2) +def test_batchnorm2d(): + class BatchNorm2d(Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, input): + return self.bn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + w3: R.Tensor((3,), dtype="float32"), + w4: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + ) = R.nn.batch_norm( + input_1, + w1, + w2, + w3, + w4, + axis=1, + epsilon=1e-05, + center=True, + scale=True, + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = BatchNorm2d().eval() + binding = { + "w1": model.bn.weight.detach().numpy(), + "w2": model.bn.bias.detach().numpy(), + "w3": model.bn.running_mean.detach().numpy(), + "w4": model.bn.running_var.detach().numpy(), + } + verify_model(model, example_args, binding, expected1) + + def test_adaptive_avgpool2d(): class AdaptiveAvgPool2d0(Module): def __init__(self): @@ -1165,28 +1218,594 @@ def __init__(self): def forward(self, input): return self.pool(input) - class AdaptiveAvgPool2d1(Module): + class AdaptiveAvgPool2d1(Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d( + input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW" + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, 
dtype=torch.float32),) + verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1) + verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1) + + +def test_addmm(): + class Addmm1(Module): + def __init__(self): + super().__init__() + + def forward(self, x1, x2, x3): + return torch.addmm(x1, x2, x3) + + class Addmm2(Module): + def __init__(self): + super().__init__() + + def forward(self, x1, x2, x3): + return torch.addmm(x1, x2, x3, beta=0.8, alpha=0.5) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + @tvm.script.ir_module + class expected2: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv, R.const(0.5, "float32")) + lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1, R.const(0.8, "float32")) + lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + example_args = ( + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + torch.randn(10, 10, dtype=torch.float32), + ) + + verify_model(Addmm1(), example_args, {}, expected1) + verify_model(Addmm2(), example_args, {}, expected2) + + +def test_avg_pool2d(): + class AvgPool2d1(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.avg_pool2d( + input_1, + pool_size=[1, 1], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class AvgPool2d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True) + + def forward(self, input): + return self.pool(input) + + class AvgPool2d3(Module): + def forward(self, input): + return torch.nn.functional.avg_pool2d( + input, kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True + ) + + @tvm.script.ir_module + class expected2: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool2d( + input_1, + pool_size=[4, 4], + strides=[2, 2], + dilation=[1, 1], + padding=[2, 2, 2, 2], + ceil_mode=True, + layout="NCHW", + out_layout="NCHW", + ) + gv = (lv,) + R.output(gv) + return gv + + class AvgPool2d4(Module): + def forward(self, input): + return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1], divisor_override=2) + + 
@tvm.script.ir_module + class expected3: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool2d( + input_1, + pool_size=[2, 1], + strides=[2, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + ceil_mode=False, + layout="NCHW", + out_layout="NCHW", + ) + gv = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(AvgPool2d1(), example_args, {}, expected1) + verify_model(AvgPool2d2(), example_args, {}, expected2) + verify_model(AvgPool2d3(), example_args, {}, expected2) + verify_model(AvgPool2d4(), example_args, {}, expected3) + + +def test_baddbmm(): + class BAddBMM1(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y) + + @tvm.script.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, inp_0) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + class BAddBMM2(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y, alpha=2, beta=0) + + @tvm.script.ir_module + class Expected2: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( + lv, R.const(2, "float32") + ) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + class BAddBMM3(Module): + def __init__(self): + super().__init__() + + def forward(self, c, x, y): + return torch.baddbmm(c, x, y, alpha=2, beta=3) + + @tvm.script.ir_module + class Expected3: + @R.function + def main( + inp_0: R.Tensor((4, 128, 512), dtype="float32"), + inp_1: R.Tensor((4, 128, 256), dtype="float32"), + inp_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) + lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( + lv, R.const(2, "float32") + ) + lv2: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( + inp_0, R.const(3, "float32") + ) + lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + example_args = ( + torch.randn(4, 128, 512, dtype=torch.float32), + torch.randn(4, 128, 256, dtype=torch.float32), + torch.randn(4, 256, 512, dtype=torch.float32), + ) + verify_model( + BAddBMM1(), + example_args, + {}, + Expected1, + ) + + verify_model( + BAddBMM2(), + example_args, + {}, + Expected2, + ) + + verify_model( + BAddBMM3(), + example_args, + {}, + Expected3, + ) + + +def test_bmm(): + class BMM(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.bmm(x, y) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input_1: R.Tensor((4, 
128, 256), dtype="float32"), + input_2: R.Tensor((4, 256, 512), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.randn(4, 128, 256, dtype=torch.float32), + torch.randn(4, 256, 512, dtype=torch.float32), + ) + verify_model( + BMM(), + example_args, + {}, + Expected, + ) + + +def test_conv_transpose1d(): + class ConvTranspose1d1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=True) + + def forward(self, input): + return self.conv(input) + + class ConvTranspose1d1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 6, 3]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv_transpose1d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 6, 4), dtype="float32"), + w1: R.Tensor((6, 6, 3), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1]) + lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class ConvTranspose1d2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 6, 4), dtype="float32"), + w1: R.Tensor((6, 6, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 6, 4, dtype=torch.float32),) + + model = ConvTranspose1d1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose1d1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose1d2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + +def test_conv_transpose2d(): + class ConvTranspose2d1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class ConvTranspose2d1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[3, 3, 7, 7]) + self.bias = torch.randn(size=[3]) + + def forward(self, input): + return 
torch.nn.functional.conv_transpose2d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3, 3, 7, 7), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1]) + lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class ConvTranspose2d2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3, 3, 7, 7), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = ConvTranspose2d1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose2d1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = ConvTranspose2d2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + +def test_conv1d(): + class Conv1D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class Conv1D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv1d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + w1: R.Tensor((6, 3, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + input_1: R.Tensor((1, 3, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1]) + lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class Conv1D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) + def forward(self, input): - return 
torch.nn.functional.adaptive_avg_pool2d(input, [10, 10]) + return self.conv(input) @tvm.script.ir_module - class expected1: + class expected2: @R.function def main( - input_1: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + w1: R.Tensor((6, 3, 7), dtype="float32"), + input_1: R.Tensor((1, 3, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")): # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d( - input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW" + lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( + input_1, + w1, + strides=[1], + padding=[0, 0], + dilation=[1], + data_layout="NCW", + kernel_layout="OIW", + out_layout="NCW", + out_dtype="float32", ) - gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv1,) R.output(gv) return gv - example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) - verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1) - verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1) + example_args = (torch.randn(1, 3, 10, dtype=torch.float32),) + + model = Conv1D1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv1D1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv1D2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) def test_conv2d(): @@ -1281,6 +1900,267 @@ def main( verify_model(model, example_args, binding, expected2) +def test_conv3d(): + class Conv3D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class Conv3D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7, 7, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv3d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d( + input_1, + w1, + strides=[1], + padding=[0, 0, 0], + dilation=[1], + data_layout="NCDHW", + kernel_layout="OIDHW", + out_layout="NCDHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class Conv3D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d( + input_1, + w1, + 
+                    strides=[1],
+                    padding=[0, 0, 0],
+                    dilation=[1],
+                    data_layout="NCDHW",
+                    kernel_layout="OIDHW",
+                    out_layout="NCDHW",
+                    out_dtype="float32",
+                )
+                gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)
+
+    model = Conv3D1()
+    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = Conv3D1Func()
+    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = Conv3D2()
+    binding = {"w1": model.conv.weight.detach().numpy()}
+    verify_model(model, example_args, binding, expected2)
+
+
+def test_einsum():
+    class Einsum1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.einsum("ii", x)
+
+    class Einsum2(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.einsum("i,j->ij", x, y)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((4, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii")
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((5, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5, 4), dtype="float32") = R.einsum(
+                    (inp_0, inp_1), subscripts="i,j->ij"
+                )
+                gv: R.Tuple(R.Tensor((5, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(4, 4, dtype=torch.float32),)
+    verify_model(Einsum1(), example_args, {}, Expected1)
+
+    example_args = (torch.randn(5, dtype=torch.float32), torch.randn(4, dtype=torch.float32))
+    verify_model(Einsum2(), example_args, {}, Expected2)
+
+
+def test_embedding():
+    class Embedding(Module):
+        def __init__(self):
+            super().__init__()
+            self.embedding = torch.nn.Embedding(10, 3)
+
+        def forward(self, input):
+            return self.embedding(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, dtype="int32")
+                lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0)
+                gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randint(low=-int(1e5), high=int(1e5), size=(4,), dtype=torch.int64),)
+
+    model = Embedding()
+    binding = {"w1": model.embedding.weight.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+
+def test_groupnorm():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class GroupNorm(Module):
+        def __init__(self):
+            super().__init__()
+            self.gn = torch.nn.GroupNorm(3, 3)
+
+        def forward(self, input):
+            return self.gn(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    num_groups=3,
+                    channel_axis=1,
+                    axes=[2, 3],
+                    epsilon=1.0000000000000001e-05,
+                    center=True,
+                    scale=True,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    model = GroupNorm()
+    binding = {
+        "w1": model.gn.weight.detach().numpy(),
+        "w2": model.gn.bias.detach().numpy(),
+    }
+    verify_model(model, example_args, binding, expected1)
+
+
+def test_layernorm():
+    class LayerNorm(Module):
+        def __init__(self):
+            super().__init__()
+            self.ln = torch.nn.LayerNorm((10, 10))
+
+        def forward(self, input):
+            return self.ln(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((10, 10), dtype="float32"),
+            w2: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    axes=[-2, -1],
+                    epsilon=1e-05,
+                    center=True,
+                    scale=True,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    model = LayerNorm()
+    binding = {
+        "w1": model.ln.weight.detach().numpy(),
+        "w2": model.ln.bias.detach().numpy(),
+    }
+    verify_model(model, example_args, binding, expected1)
+
+
 def test_linear():
     class Dense1(Module):
         def __init__(self):
@@ -1460,6 +2340,254 @@ def main(
     verify_model(MaxPool2d3(), example_args, {}, expected3)
 
 
+def test_scaled_dot_product_attention():
+    class Attention1(Module):
+        def forward(self, q, k, v):
+            return torch.nn.functional.scaled_dot_product_attention(q, k, v)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_0, axes=[0, 2, 1, 3]
+                )
+                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_1, axes=[0, 2, 1, 3]
+                )
+                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_2, axes=[0, 2, 1, 3]
+                )
+                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
+                    lv, lv1, lv2, scale=None
+                )
+                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+                    lv3, axes=[0, 2, 1, 3]
+                )
+                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,)
+                R.output(gv)
+            return gv
+
+    class Attention2(Module):
+        def forward(self, q, k, v, mask):
+            return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_0, axes=[0, 2, 1, 3]
+                )
+                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_1, axes=[0, 2, 1, 3]
+                )
+                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_2, axes=[0, 2, 1, 3]
+                )
+                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
+                    lv, lv1, lv2, inp_3, scale=None
+                )
+                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+                    lv3, axes=[0, 2, 1, 3]
+                )
+                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,)
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Attention1(),
+        (
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+        ),
+        {},
+        Expected1,
+    )
+
+    verify_model(
+        Attention2(),
+        (
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+            torch.randn(32, 8, 128, 128, dtype=torch.float32),
+        ),
+        {},
+        Expected2,
+    )
+
+
+def test_unbind():
+    class Unbind1(Module):
+        def forward(self, data):
+            return torch.unbind(data)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((0, 3, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0])
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0])
+                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2]
+                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0])
+                lv7: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv2, lv4, lv6)
+                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                gv: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv8, lv9, lv10)
+                R.output(gv)
+            return gv
+
+    class Unbind2(Module):
+        def forward(self, data):
+            return torch.unbind(data, dim=1)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 0, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1)
+                lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1])
+                lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1])
+                lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2]
+                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1])
+                lv7: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv2, lv4, lv6)
+                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                gv: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv8, lv9, lv10)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Unbind1(), example_args, {}, expected1)
+    verify_model(Unbind2(), example_args, {}, expected2)
+
+
+def test_interpolate():
+    class InterpolateBilinear(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, (224, 224), mode="bilinear")
+
+    @tvm.script.ir_module
+    class expected_bilinear:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 112, 112), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d(
+                    input,
+                    R.shape([224, 224]),
+                    roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)],
+                    layout="NCHW",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="round",
+                    cubic_alpha=-0.5,
+                    cubic_exclude=0,
+                    extrapolation_value=0.0,
+                    out_dtype="void",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class InterpolateNearest(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, (224, 224), mode="nearest")
+
+    @tvm.script.ir_module
+    class expected_nearest:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 112, 112), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d(
+                    input,
+                    R.shape([224, 224]),
+                    roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)],
+                    layout="NCHW",
+                    method="nearest_neighbor",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="round",
+                    cubic_alpha=-0.5,
+                    cubic_exclude=0,
+                    extrapolation_value=0.0,
+                    out_dtype="void",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),)
+    verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear)
+    verify_model(InterpolateNearest(), example_args, {}, expected_nearest)
+
+
 def test_mean():
     class Mean(Module):
         def forward(self, input):