Skip to content

Commit

Permalink
fix test_graph_input_use_in_if work on llvm test
Browse files Browse the repository at this point in the history
  • Loading branch information
gangmul12 committed Aug 22, 2023
1 parent c5b028f commit 21c6a07
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5157,7 +5157,7 @@ def verify_if(num_nested, cond):
input_tensor = helper.make_tensor_value_info("graph_input", TensorProto.FLOAT, [1])
output_tensor = helper.make_tensor_value_info("graph_output", TensorProto.FLOAT, [1])
constant_node = make_constant_node("const_val", TensorProto.FLOAT, [1], [-1])
cond_tensor = helper.make_tensor_value_info("cond", TensorProto.BOOL, [])
cond_tensor = helper.make_tensor_value_info("cond", TensorProto.BOOL, [1])
inner_if_node = None
for i in range(num_nested):
identity_node = helper.make_node(
Expand Down Expand Up @@ -5200,7 +5200,12 @@ def verify_if(num_nested, cond):
inner_if_node = if_node
else:
then_branch = helper.make_graph(
[inner_if_node], f"then{i}_body", [], [f"if_output{i-1}"]
[inner_if_node],
f"then{i}_body",
inputs=[],
outputs=[
helper.make_tensor_value_info(f"if_output{i-1}", TensorProto.FLOAT, [1])
],
)
if_node = helper.make_node(
"If",
Expand All @@ -5222,10 +5227,12 @@ def verify_if(num_nested, cond):

verify_with_ort_with_inputs(
model,
[np.array([3.0], dtype="float32"), [cond]],
[np.array([3.0], dtype="float32"), np.array([cond])],
dtype="float32",
use_vm=True,
opset=14,
target=target,
dev=dev,
)

# Confirm that if works with cond as an array or scalar.
Expand Down

0 comments on commit 21c6a07

Please sign in to comment.