Tensor.empty op has different semantics for cuda.

cuda_f16
Prashant Kumar 2022-12-13 09:38:45 +00:00
parent 5351db0e99
commit 1eccac1264
1 changed files with 15 additions and 12 deletions

View File

@ -183,15 +183,15 @@ public:
}
// TODO: Add support for device arg other than cpu.
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
std::string device;
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
return rewriter.notifyMatchFailure(
op, "unimplemented: device must be a constant str");
else if (device != "cpu")
return rewriter.notifyMatchFailure(
op, "unimplemented: device is expected to be cpu");
}
//if (!op.getDevice().getType().isa<Torch::NoneType>()) {
//std::string device;
//if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
//return rewriter.notifyMatchFailure(
//op, "unimplemented: device must be a constant str");
//else if (device != "cpu")
//return rewriter.notifyMatchFailure(
//op, "unimplemented: device is expected to be cpu");
//}
// TODO: Add support for non-strided layout.
// torch.layout is by default strided i.e. 0.
@ -232,10 +232,13 @@ public:
IntegerType::Signless);
}
Value constVal = getConstant(rewriter, loc, 0, resultElementType);
// Create an uninitialized tensor of `resultSize` shape.
Value initTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(resultSizeIndex), resultElementType);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
// Value initTensor = rewriter.create<tensor::EmptyOp>(
// loc, getAsOpFoldResult(resultSizeIndex), resultElementType);
Value outputTensor = createInitTensor(rewriter, loc, resultSizeIndex,
resultElementType, constVal);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
return success();
}
};