
Possibly incorrect expected shapes for STFT op #1278

Open
alexander-camuto opened this issue Dec 4, 2023 · 0 comments

Comments

@alexander-camuto
Contributor

PyTorch's `torch.stft` accepts a 1D or 2D input tensor, so the following runs fine:

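```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # n_fft=8, hop_length=4, win_length=8; real output with a trailing dim of size 2
        stft = torch.stft(x, 8, 4, 8, return_complex=False)
        return stft


circuit = MyModel()

# 2D input: (batch=1, signal_length=8)
x = torch.empty(1, 8).uniform_(0, 1)

out = circuit(x)

print(out)
```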
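For reference, the output shape `torch.stft` produces here can be worked out by hand. The sketch below is a pure-Python helper (`stft_output_shape` is a hypothetical name, not a PyTorch API) that assumes the settings used in the repro: `center=True` (padding `n_fft // 2` samples on each side of the signal), `onesided=True` for a real input, and `return_complex=False`, which adds a trailing dimension of size 2.

```python
def stft_output_shape(input_shape, n_fft, hop_length, center=True,
                      onesided=True, return_complex=False):
    """Hypothetical helper: shape of torch.stft's output for a real 1D/2D input."""
    if len(input_shape) not in (1, 2):
        raise ValueError("torch.stft expects a 1D or 2D input")
    *batch, length = input_shape
    if center:
        # center=True pads n_fft // 2 samples on both ends of the signal
        length += 2 * (n_fft // 2)
    frames = 1 + (length - n_fft) // hop_length
    # Real input with onesided=True keeps only the non-redundant frequency bins
    freqs = n_fft // 2 + 1 if onesided else n_fft
    shape = (*batch, freqs, frames)
    # return_complex=False stores real and imaginary parts in a last dim of size 2
    return shape if return_complex else shape + (2,)


# The repro: x has shape (1, 8), n_fft=8, hop_length=4
print(stft_output_shape((1, 8), n_fft=8, hop_length=4))  # (1, 5, 3, 2)
```

The padded length of 16 matches the `Pad` node (#29) in tract's dump below, where the shape becomes `1,batch_size,16`.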

However, tract fails to analyze the exported model (ONNX attached) with the following error:

```
0 Source input
┃   ━━━ batch_size,8,F32
┣┓ 
┃┣ 1 Shape /Shape
┃┃   ━━━ 2,TDim batch_size, 8
┃┣┻ 3 Gather /Gather
┃┃   ━━━ ,TDim batch_size
┃┣┻ 9 Unsqueeze13 /Unsqueeze
┃┃   ━━━ 1,TDim batch_size
┣━┓ 
┃┃┣ 4 Shape /Shape_1
┃┃┃   ━━━ 2,TDim batch_size, 8
┃┃┣┻ 6 Gather /Gather_1
┃┃┃   ━━━ ,TDim 8
┃┃┣┻ 11 Unsqueeze13 /Unsqueeze_1
┃┃┃   ━━━ 1,TDim 8
 
┃ 
┃┣┻┻ 12 InferenceConcat /Concat
┃┃   ━━━ 3,TDim 1, batch_size, 8
┣┻ 13 Reshape /Reshape
┃   ━━━ 1,batch_size,8,F32
┃┣ 16 ConstantOfShape /ConstantOfShape
┃┃   ━━━ 4,I64 0, 0, 0, 0
┃┣┻ 17 InferenceConcat /Concat_1
┃┃   ━━━ 6,I64 4, 4, 0, 0, 0, 0
┃┣┻ 19 Reshape /Reshape_1
┃┃   ━━━ 3,2,I64 4, 4, 0, 0, 0, 0
┃┣┻┻┻┻ 24 StridedSlice /Slice
┃┃   ━━━ 3,2,I64 0, 0, 0, 0, 4, 4
┃┣ 25 PermuteAxes /Transpose
┃┃   ━━━ 2,3,I64 0, 0, 4, 0, 0, 4
┃┣┻ 27 Reshape /Reshape_2
┃┃   ━━━ 6,I64 0, 0, 4, 0, 0, 4
┃┣ 28 onnx.Cast /Cast
┃┃   ━━━ 6,TDim 0, 0, 4, 0, 0, 4
┣┻ 29 Pad /Pad
┃   ━━━ 1,batch_size,16,F32
┣┓ 
┃┣ 30 Shape /Shape_2
┃┃   ━━━ 3,TDim 1, batch_size, 16
┃┣┻ 32 Gather /Gather_2
┃┃   ━━━ ,TDim batch_size
┃┣┻ 37 Unsqueeze13 /Unsqueeze_2
┃┃   ━━━ 1,TDim batch_size
┣━┓ 
┃┃┣ 33 Shape /Shape_3
┃┃┃   ━━━ 3,TDim 1, batch_size, 16
┃┃┣┻ 35 Gather /Gather_3
┃┃┃   ━━━ ,TDim 16
┃┃┣┻ 39 Unsqueeze13 /Unsqueeze_3
┃┃┃   ━━━ 1,TDim 16
┃┣┻ 40 InferenceConcat /Concat_2
┃┃   ━━━ 2,TDim batch_size, 16
┣┻ 41 Reshape /Reshape_3
┃   ━━━ batch_size,16,F32
┣┻┻┻ 45 STFT /STFT
┃   ━━━ ?,?,?,?,?
┣ 46 PermuteAxes /Transpose_1
    ━━━ ?,?,?,?,?
[2023-12-04T15:07:47.426876000Z ERROR tract] Error at stage analyse
    
    Caused by:
        0: Failed analyse for node #45 "/STFT" STFT
        1: Failed analyse for node #45 "/STFT" STFT
        2: Infering facts
        3: Applying rule inputs[0].rank == 3
        4: Impossible to unify 2 with 3.
```
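The failing rule, `inputs[0].rank == 3`, lines up with the ONNX `STFT` operator specification, which documents the signal input as `[batch_size][signal_length][1]` for real signals (or `[batch_size][signal_length][2]` for complex ones), whereas the exported graph feeds a rank-2 `batch_size,16` tensor (node #41) into `/STFT`. The sketch below is a hypothetical illustration, not tract's code: it shows how a consumer could accept the rank-2 form by appending a trailing axis of size 1, then computes the spec's output shape `[batch_size][frames][dft_unique_bins][2]`. The permutation `(0, 2, 1, 3)` used to recover PyTorch's layout is an assumption about what `/Transpose_1` (node #46) does.

```python
def normalize_stft_signal_shape(shape):
    """Hypothetical: map a signal shape to the rank-3 layout the ONNX spec describes."""
    if len(shape) == 3:
        if shape[2] not in (1, 2):
            raise ValueError("last axis must be 1 (real) or 2 (complex)")
        return tuple(shape)
    if len(shape) == 2:
        # Rank-2 [batch_size, signal_length], as exported from PyTorch:
        # treat it as a real signal by appending an axis of size 1.
        return (*shape, 1)
    raise ValueError(f"unsupported STFT signal rank {len(shape)}")


def onnx_stft_output_shape(signal_shape, frame_step, frame_length, onesided=True):
    """Output shape [batch_size][frames][dft_unique_bins][2] per the ONNX STFT spec."""
    batch, length, _ = normalize_stft_signal_shape(signal_shape)
    frames = 1 + (length - frame_length) // frame_step
    bins = frame_length // 2 + 1 if onesided else frame_length
    return (batch, frames, bins, 2)


def permute(shape, perm):
    """Reorder shape axes, as a Transpose node with the given perm would."""
    return tuple(shape[p] for p in perm)


# Signal reaching /STFT in the failing graph: batch_size=1, 16 padded samples
onnx_out = onnx_stft_output_shape((1, 16), frame_step=4, frame_length=8)
print(onnx_out)                         # (1, 3, 5, 2)
print(permute(onnx_out, (0, 2, 1, 3)))  # (1, 5, 3, 2)
```

With that normalization, the graph's shapes line up with the `(1, 5, 3, 2)` tensor PyTorch returns for the repro.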

network.onnx.zip
