mirror of https://github.com/llvm/torch-mlir
The Tensor.empty op has different semantics for CUDA.
parent
5351db0e99
commit
1eccac1264
|
@ -183,15 +183,15 @@ public:
|
|||
}
|
||||
|
||||
// TODO: Add support for device arg other than cpu.
|
||||
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
|
||||
std::string device;
|
||||
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: device must be a constant str");
|
||||
else if (device != "cpu")
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: device is expected to be cpu");
|
||||
}
|
||||
//if (!op.getDevice().getType().isa<Torch::NoneType>()) {
|
||||
//std::string device;
|
||||
//if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
||||
//return rewriter.notifyMatchFailure(
|
||||
//op, "unimplemented: device must be a constant str");
|
||||
//else if (device != "cpu")
|
||||
//return rewriter.notifyMatchFailure(
|
||||
//op, "unimplemented: device is expected to be cpu");
|
||||
//}
|
||||
|
||||
// TODO: Add support for non-strided layout.
|
||||
// torch.layout is by default strided i.e. 0.
|
||||
|
@ -232,10 +232,13 @@ public:
|
|||
IntegerType::Signless);
|
||||
}
|
||||
|
||||
Value constVal = getConstant(rewriter, loc, 0, resultElementType);
|
||||
// Create an uninitialized tensor of `resultSize` shape.
|
||||
Value initTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(resultSizeIndex), resultElementType);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
|
||||
// Value initTensor = rewriter.create<tensor::EmptyOp>(
|
||||
// loc, getAsOpFoldResult(resultSizeIndex), resultElementType);
|
||||
Value outputTensor = createInitTensor(rewriter, loc, resultSizeIndex,
|
||||
resultElementType, constVal);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue