Fix creation of empty tensor in decomposition for randn ops (#2043)

The current decomposition for `aten.randn.generator` does not specify
the `dtype` argument of the empty tensors created to store the random
values. This leads to invalid IR when the output type of the `randn`
op is not the default PyTorch dtype.
pull/2057/head
Ramiro Leal-Cavazos 2023-04-19 08:25:39 -07:00
parent 6543db7071
commit af01ee742e
3 changed files with 37 additions and 4 deletions

View File

@@ -226,7 +226,8 @@ public:
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
Type resultElementType;
if (op.getDtype().getType().isa<Torch::NoneType>()) {
resultElementType = resultType.getElementType();
resultElementType = getDefaultDtypeForTorchScalar(
Torch::FloatType::get(op->getContext()));
} else {
int64_t dtypeInt;
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))

View File

@@ -3716,8 +3716,14 @@ public:
LogicalResult matchAndRewrite(AtenRandnGeneratorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type resultType = op.getType();
auto resultType = op.getType().cast<BaseTensorType>();
if (!resultType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");
}
Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
Value none = rewriter.create<ConstantNoneOp>(loc);
Value low = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr((double)0.0));
@@ -3729,11 +3735,13 @@ public:
loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159)));
Value emptyTensorA = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
loc, resultType, op.getSize(), /*dtype=*/dtype,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/none);
Value emptyTensorB = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, resultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
loc, resultType, op.getSize(), /*dtype=*/dtype,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/none);

View File

@@ -400,6 +400,30 @@ def RandnGeneratorModule_basic(module, tu: TestUtils):
# ==============================================================================
class RandnGeneratorF64Module(torch.nn.Module):
    """E2E test module exercising `aten.randn` with a non-default dtype.

    `dtype=torch.float64` hits the decomposition path that previously
    emitted empty tensors without a dtype, producing invalid IR (the bug
    this commit fixes). The module returns the standard deviation of the
    generated values rather than the raw tensor so the result is a scalar.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        # generator=None selects the default generator; dtype is the key
        # argument under test here.
        a = torch.ops.aten.randn([4, 512, 1024], generator=None,
                                 dtype=torch.float64)
        std = torch.std(a)
        return std


@register_test_case(module_factory=lambda: RandnGeneratorF64Module())
def RandnGeneratorF64Module_basic(module, tu: TestUtils):
    # No inputs: forward() takes only `self`.
    module.forward()
# ==============================================================================
class RandnLikeModule(torch.nn.Module):
def __init__(self):
super().__init__()