Tensor.empty op has different semantics for cuda.

cuda_f16
Prashant Kumar 2022-12-13 09:38:45 +00:00
parent 5351db0e99
commit 1eccac1264
1 changed files with 15 additions and 12 deletions

View File

@ -183,15 +183,15 @@ public:
} }
// TODO: Add support for device arg other than cpu. // TODO: Add support for device arg other than cpu.
if (!op.getDevice().getType().isa<Torch::NoneType>()) { //if (!op.getDevice().getType().isa<Torch::NoneType>()) {
std::string device; //std::string device;
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) //if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
return rewriter.notifyMatchFailure( //return rewriter.notifyMatchFailure(
op, "unimplemented: device must be a constant str"); //op, "unimplemented: device must be a constant str");
else if (device != "cpu") //else if (device != "cpu")
return rewriter.notifyMatchFailure( //return rewriter.notifyMatchFailure(
op, "unimplemented: device is expected to be cpu"); //op, "unimplemented: device is expected to be cpu");
} //}
// TODO: Add support for non-strided layout. // TODO: Add support for non-strided layout.
// torch.layout is by default strided i.e. 0. // torch.layout is by default strided i.e. 0.
@ -232,10 +232,13 @@ public:
IntegerType::Signless); IntegerType::Signless);
} }
Value constVal = getConstant(rewriter, loc, 0, resultElementType);
// Create an uninitialized tensor of `resultSize` shape. // Create an uninitialized tensor of `resultSize` shape.
Value initTensor = rewriter.create<tensor::EmptyOp>( // Value initTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(resultSizeIndex), resultElementType); // loc, getAsOpFoldResult(resultSizeIndex), resultElementType);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor); Value outputTensor = createInitTensor(rewriter, loc, resultSizeIndex,
resultElementType, constVal);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
return success(); return success();
} }
}; };