
arith constant in stable HLO graph #3803

Open
AleksKnezevic opened this issue Oct 18, 2024 · 0 comments
When converting a simple linear module with bias into a StableHLO graph, one constant is emitted in the arith dialect rather than in stablehlo. Is this expected?

Here is the code to produce it:

import torch
import torch.nn as nn
from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType

def test_linear_with_bias():
  class Basic(nn.Module):
    def __init__(self):
      super().__init__()
      self.linear_a = nn.Linear(32, 32)

    def forward(self, x):
      x = self.linear_a(x)
      return x

  mod = fx.export_and_import(Basic(), torch.randint(0, 100, (1, 32)), output_type=OutputType.STABLEHLO)
  mod.dump()

and here is the result:

test/test_basic.py::test_linear_with_bias module {
  func.func @main(%arg0: tensor<1x32xi64>) -> tensor<1x32xf32> {
    %cst = stablehlo.constant dense_resource<torch_tensor_32_torch.float32> : tensor<32xf32>
    %cst_0 = stablehlo.constant dense_resource<torch_tensor_32_32_torch.float32> : tensor<32x32xf32>
    %cst_1 = arith.constant dense<1> : tensor<1xi64>
    %0 = stablehlo.transpose %cst_0, dims = [1, 0] : (tensor<32x32xf32>) -> tensor<32x32xf32>
    %1 = stablehlo.convert %arg0 : (tensor<1x32xi64>) -> tensor<1x32xf32>
    %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0] : (tensor<1x32xf32>, tensor<32x32xf32>) -> tensor<1x32xf32>
    %3 = stablehlo.convert %cst_1 : (tensor<1xi64>) -> tensor<1xf32>
    %4 = stablehlo.reshape %3 : (tensor<1xf32>) -> tensor<f32>
    %5 = stablehlo.broadcast_in_dim %2, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<1x32xf32>
    %6 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<1x32xf32>
    %7 = stablehlo.multiply %5, %6 : tensor<1x32xf32>
    %8 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<32xf32>
    %9 = stablehlo.broadcast_in_dim %cst, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %10 = stablehlo.multiply %8, %9 : tensor<32xf32>
    %11 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<1x32xf32>
    %12 = stablehlo.broadcast_in_dim %10, dims = [1] : (tensor<32xf32>) -> tensor<1x32xf32>
    %13 = stablehlo.add %11, %12 : tensor<1x32xf32>
    return %13 : tensor<1x32xf32>
  }
}

I'm using torch-mlir to interface with a compiler that accepts only StableHLO dialect input. I'm happy either to add support for arith constants in my compiler, or to change how the constant is emitted from torch-mlir and submit a PR; I just wanted to know what the expected behaviour is.
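As a stopgap while the expected behaviour is clarified, one possible workaround (a rough sketch, not an official torch-mlir API) is to rewrite the offending op textually before handing the module to the downstream compiler. Both `arith.constant` and `stablehlo.constant` print in the same form (`<dialect>.constant <attr> : <tensor-type>`), so for a simple dense tensor constant like the `%cst_1` above, a plain substitution suffices. The helper name `arith_to_stablehlo_constant` is hypothetical:

```python
import re

def arith_to_stablehlo_constant(ir_text: str) -> str:
    """Rewrite arith.constant ops into stablehlo.constant ops in
    printed MLIR text. Only safe for tensor-typed dense constants,
    where the two ops share the same assembly format."""
    return re.sub(r"\barith\.constant\b", "stablehlo.constant", ir_text)

ir = "%cst_1 = arith.constant dense<1> : tensor<1xi64>"
print(arith_to_stablehlo_constant(ir))
# %cst_1 = stablehlo.constant dense<1> : tensor<1xi64>
```

A proper fix would of course be an MLIR rewrite pattern (or a change in the torch-mlir lowering) rather than text munging, but this keeps a StableHLO-only downstream compiler working in the meantime.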
