From 56c6e3676bbe1571dab0c7cc1fb5080846819438 Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Mon, 22 Nov 2021 18:27:25 +0000
Subject: [PATCH] Fix bug in NumToTensor handling of float values

This commit fixes a type promotion bug when NumToTensor was given a
float as an argument. In particular, the rules for type promotion of a
scalar vary depending on whether the scalar is part of a tensor op or
not. NumToTensor falls under the second category, but it was being
treated as part of the first category.
---
 e2e_testing/torchscript/basic.py             | 24 +++++++++++++++++---
 lib/Dialect/Torch/Transforms/RefineTypes.cpp | 14 +++++++++++-
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index 28dd1b7d1..a5c798e0b 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -630,7 +630,7 @@ def LogSoftmaxIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4).double())
 
 
-class NumToTensorModule(torch.nn.Module):
+class NumToTensorIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -642,8 +642,26 @@ class NumToTensorModule(torch.nn.Module):
     def forward(self):
         return torch.ops.prim.NumToTensor(1)
 
-@register_test_case(module_factory=lambda: NumToTensorModule())
-def NumToTensorModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: NumToTensorIntModule())
+def NumToTensorIntModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class NumToTensorFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+
+    def forward(self):
+        return torch.ops.prim.NumToTensor(1.0)
+
+
+@register_test_case(module_factory=lambda: NumToTensorFloatModule())
+def NumToTensorFloatModule_basic(module, tu: TestUtils):
     module.forward()
 
 
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 6cf5a9ec1..12b803c5c 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -1290,7 +1290,19 @@ ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
   auto knowledge =
       ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
   knowledge.hasSizes = true;
-  knowledge.dtype = getDefaultDtypeForTorchScalar(op.a().getType());
+
+  // The resulting type from converting a Scalar into a Tensor is different
+  // if the scalar is part of a tensor operation (such as AtenMulScalar) or
+  // not. In the former case, the type promotion rules are captured by the
+  // `getDefaultDtypeForTorchScalar` helper above. The latter case follows the
+  // rules in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
+  // `NumToTensor` falls in the latter case.
+  Type type = op.a().getType();
+  if (type.isa<Torch::FloatType>())
+    knowledge.dtype = Float64Type::get(op.getContext());
+  else if (type.isa<Torch::IntType>())
+    knowledge.dtype = IntegerType::get(op.getContext(), 64, IntegerType::Signed);
+
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
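
For reference, a minimal eager-mode PyTorch sketch of the two scalar-promotion behaviors the commit message distinguishes. It assumes `torch.ops.prim.NumToTensor` is invocable from eager Python, which holds for PyTorch releases contemporary with this patch:

```python
import torch

# First category: a scalar participating in a tensor op. Wrapped-number
# promotion applies, so the Python float 2.0 does not upcast the float32
# tensor to float64 (this is the behavior `getDefaultDtypeForTorchScalar`
# models).
t = torch.ones(2, dtype=torch.float32)
assert (t * 2.0).dtype == torch.float32

# Second category: a standalone scalar-to-tensor conversion, i.e. the
# NumToTensor case. The ATen/ScalarOps.h rules apply: a Python float
# becomes a float64 tensor and a Python int becomes an int64 tensor.
assert torch.ops.prim.NumToTensor(1.0).dtype == torch.float64
assert torch.ops.prim.NumToTensor(1).dtype == torch.int64
```

Before this patch, the refinement pass applied the first category's rules to `NumToTensor`, so `NumToTensor(1.0)` would have been refined to float32 rather than the float64 the runtime actually produces.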