mirror of https://github.com/llvm/torch-mlir
Fix creation of empty tensor in decomposition for randn ops (#2043)
The current decomposition for `aten.randn.generator` does not specify the `dtype` argument of the empty tensors created to store the random values. This leads to invalid IR when the output type of the `randn` op is not the default PyTorch dtype.
parent dbbcc4aaff, commit f85f5799e4
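For context, here is a minimal eager-mode sketch (not part of the commit) of the behavior the decomposition has to preserve: when the generator overload of `aten.randn` is given an explicit `dtype`, the result carries that dtype rather than PyTorch's default, so the empty tensors created by the decomposition must be given the same dtype instead of `none`.

```python
import torch

# Illustrative only: the generator overload of aten.randn accepts an explicit
# dtype, and the result tensor must carry it.
default_dtype = torch.ops.aten.randn([2, 3], generator=None)
explicit_dtype = torch.ops.aten.randn([2, 3], generator=None, dtype=torch.float64)

print(default_dtype.dtype)   # torch.float32 (PyTorch's default dtype)
print(explicit_dtype.dtype)  # torch.float64
```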
@@ -226,7 +226,8 @@ public:
         typeConverter->convertType(op.getType()).cast<RankedTensorType>();
     Type resultElementType;
     if (op.getDtype().getType().isa<Torch::NoneType>()) {
-      resultElementType = resultType.getElementType();
+      resultElementType = getDefaultDtypeForTorchScalar(
+          Torch::FloatType::get(op->getContext()));
     } else {
       int64_t dtypeInt;
       if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
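The hunk above changes the `dtype=None` fallback in the empty-tensor lowering to use the default Torch float dtype instead of the converted result element type, presumably to match eager semantics (an `empty` call with no dtype uses the global default dtype) now that the decomposition passes an explicit dtype. A quick eager-mode illustration, not part of the commit:

```python
import torch

# With no explicit dtype, empty follows the global default dtype.
print(torch.empty([2, 2]).dtype)  # torch.float32 under default settings
```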
@@ -3923,8 +3923,14 @@ public:
   LogicalResult matchAndRewrite(AtenRandnGeneratorOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Type resultType = op.getType();
+    auto resultType = op.getType().cast<BaseTensorType>();
+
+    if (!resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected result type to have a dtype");
+    }
+
+    Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
     Value none = rewriter.create<ConstantNoneOp>(loc);
     Value low = rewriter.create<Torch::ConstantFloatOp>(
         loc, rewriter.getF64FloatAttr((double)0.0));
@@ -3936,11 +3942,13 @@ public:
         loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159)));

     Value emptyTensorA = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
+        loc, resultType, op.getSize(), /*dtype=*/dtype,
+        /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/none);
     Value emptyTensorB = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
+        loc, resultType, op.getSize(), /*dtype=*/dtype,
+        /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/none);
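The two scratch tensors and the constants `0.0` and `2.0 * 3.14159` (2π) in the hunk above suggest the decomposition fills two uniform buffers and combines them Box-Muller style into normal samples; that is an inference from the visible constants, not something the diff states. A rough Python sketch under that assumption, with illustrative names, showing why both buffers must share the result dtype:

```python
import math
import torch

# Hypothetical sketch of a Box-Muller style lowering; names are made up.
def randn_via_box_muller(size, dtype=torch.float64):
    # Both scratch buffers must use the requested dtype, otherwise the
    # combined result would not match the declared result type.
    u = torch.empty(size, dtype=dtype).uniform_(0.0, 1.0)
    v = torch.empty(size, dtype=dtype).uniform_(0.0, 2.0 * math.pi)
    return torch.sqrt(-2.0 * torch.log(u)) * torch.cos(v)

samples = randn_via_box_muller([4, 512, 1024])
assert samples.dtype == torch.float64
```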
@@ -459,6 +459,30 @@ def RandnGeneratorModule_basic(module, tu: TestUtils):

 # ==============================================================================


+class RandnGeneratorF64Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        a = torch.ops.aten.randn([4, 512, 1024], generator=None, dtype=torch.float64)
+        std = torch.std(a)
+        return std
+
+
+@register_test_case(module_factory=lambda: RandnGeneratorF64Module())
+def RandnGeneratorF64Module_basic(module, tu: TestUtils):
+    module.forward()
+
+
+# ==============================================================================
+
+
 class RandnLikeModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
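As a quick host-side sanity check (not part of the test suite), the same computation mirrored in eager PyTorch shows the expected dtype and a standard deviation close to 1:

```python
import torch

# Eager-mode mirror of RandnGeneratorF64Module.forward(); illustrative only.
a = torch.ops.aten.randn([4, 512, 1024], generator=None, dtype=torch.float64)
assert a.dtype == torch.float64
print(torch.std(a))  # expected to be close to 1.0 for standard normal samples
```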