Skip to content

Commit

Permalink
[Stablehlo] AtenEmptyMemoryFormat remove device cpu check (#2288)
Browse files Browse the repository at this point in the history
* remove cpu check

* update dtype

---------

Co-authored-by: zhekun.zhang <[email protected]>
  • Loading branch information
zhekunz2 and zhekunz2 authored Jul 10, 2023
1 parent 05920f9 commit 6a072d4
Showing 1 changed file with 1 addition and 6 deletions.
7 changes: 1 addition & 6 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1476,15 +1476,11 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
"memory_format is supported");
}

// TODO: Add support for device arg other than cpu.
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
std::string device;
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
return rewriter.notifyMatchFailure(
op, "unimplemented: device must be a constant str");
else if (device != "cpu")
return rewriter.notifyMatchFailure(
op, "unimplemented: device is expected to be cpu");
}

// TODO: Add support for non-strided layout.
Expand Down Expand Up @@ -1515,8 +1511,7 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
Type resultElementType;
if (op.getDtype().getType().isa<Torch::NoneType>()) {
resultElementType =
getDefaultDtypeForTorchScalar(Torch::FloatType::get(op->getContext()));
resultElementType = resultType.getElementType();
} else {
int64_t dtypeInt;
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
Expand Down

0 comments on commit 6a072d4

Please sign in to comment.