mirror of https://github.com/llvm/torch-mlir
Tensor.empty op has different semantics for cuda.
parent
5351db0e99
commit
1eccac1264
|
@ -183,15 +183,15 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Add support for device arg other than cpu.
|
// TODO: Add support for device arg other than cpu.
|
||||||
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
|
//if (!op.getDevice().getType().isa<Torch::NoneType>()) {
|
||||||
std::string device;
|
//std::string device;
|
||||||
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
//if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
||||||
return rewriter.notifyMatchFailure(
|
//return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: device must be a constant str");
|
//op, "unimplemented: device must be a constant str");
|
||||||
else if (device != "cpu")
|
//else if (device != "cpu")
|
||||||
return rewriter.notifyMatchFailure(
|
//return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: device is expected to be cpu");
|
//op, "unimplemented: device is expected to be cpu");
|
||||||
}
|
//}
|
||||||
|
|
||||||
// TODO: Add support for non-strided layout.
|
// TODO: Add support for non-strided layout.
|
||||||
// torch.layout is by default strided i.e. 0.
|
// torch.layout is by default strided i.e. 0.
|
||||||
|
@ -232,10 +232,13 @@ public:
|
||||||
IntegerType::Signless);
|
IntegerType::Signless);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value constVal = getConstant(rewriter, loc, 0, resultElementType);
|
||||||
// Create an uninitialized tensor of `resultSize` shape.
|
// Create an uninitialized tensor of `resultSize` shape.
|
||||||
Value initTensor = rewriter.create<tensor::EmptyOp>(
|
// Value initTensor = rewriter.create<tensor::EmptyOp>(
|
||||||
loc, getAsOpFoldResult(resultSizeIndex), resultElementType);
|
// loc, getAsOpFoldResult(resultSizeIndex), resultElementType);
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
|
Value outputTensor = createInitTensor(rewriter, loc, resultSizeIndex,
|
||||||
|
resultElementType, constVal);
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue