From 1eccac1264ede665d881ddc44e71d0efb1d84da7 Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Tue, 13 Dec 2022 09:38:45 +0000 Subject: [PATCH] Tensor.empty op has different semantics for cuda. --- .../TorchToLinalg/TensorConstructors.cpp | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 42ec8657e..538c9f15c 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -183,15 +183,15 @@ public: } // TODO: Add support for device arg other than cpu. - if (!op.getDevice().getType().isa()) { - std::string device; - if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) - return rewriter.notifyMatchFailure( - op, "unimplemented: device must be a constant str"); - else if (device != "cpu") - return rewriter.notifyMatchFailure( - op, "unimplemented: device is expected to be cpu"); - } + //if (!op.getDevice().getType().isa()) { + //std::string device; + //if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) + //return rewriter.notifyMatchFailure( + //op, "unimplemented: device must be a constant str"); + //else if (device != "cpu") + //return rewriter.notifyMatchFailure( + //op, "unimplemented: device is expected to be cpu"); + //} // TODO: Add support for non-strided layout. // torch.layout is by default strided i.e. 0. @@ -232,10 +232,13 @@ public: IntegerType::Signless); } + Value constVal = getConstant(rewriter, loc, 0, resultElementType); // Create an uninitialized tensor of `resultSize` shape. - Value initTensor = rewriter.create( - loc, getAsOpFoldResult(resultSizeIndex), resultElementType); - rewriter.replaceOpWithNewOp(op, resultType, initTensor); + // Value initTensor = rewriter.create( + // loc, getAsOpFoldResult(resultSizeIndex), resultElementType); + Value outputTensor = createInitTensor(rewriter, loc, resultSizeIndex, + resultElementType, constVal); + rewriter.replaceOpWithNewOp(op, resultType, outputTensor); return success(); } };