Fix bug in NumToTensor handling of float values

This commit fixes a type promotion bug when NumToTensor was given a
float as an argument. In particular, the rules for type promotion of a
scalar vary depending on if the scalar is part of a tensor op or
not. NumToTensor falls under the second category, but it was being
treated as part of the first category.
pull/438/head snapshot-20211123.102
Ramiro Leal-Cavazos 2021-11-22 18:27:25 +00:00 committed by Yi Zhang
parent 1dc374014b
commit 56c6e3676b
2 changed files with 34 additions and 4 deletions

View File

@ -630,7 +630,7 @@ def LogSoftmaxIntModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4).double())
class NumToTensorModule(torch.nn.Module):
class NumToTensorIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -642,8 +642,26 @@ class NumToTensorModule(torch.nn.Module):
def forward(self):
return torch.ops.prim.NumToTensor(1)
@register_test_case(module_factory=lambda: NumToTensorModule())
def NumToTensorModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: NumToTensorIntModule())
def NumToTensorIntModule_basic(module, tu: TestUtils):
module.forward()
class NumToTensorFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.prim.NumToTensor(1.0)
@register_test_case(module_factory=lambda: NumToTensorFloatModule())
def NumToTensorFloatModule_basic(module, tu: TestUtils):
module.forward()

View File

@ -1290,7 +1290,19 @@ ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = true;
knowledge.dtype = getDefaultDtypeForTorchScalar(op.a().getType());
// The resulting type from converting a Scalar into a Tensor is different
// if the scalar is part of a tensor operation (such as AtenMulScalar) or
// not. In the former case, the type promotion rules are captured by the
// `getDefaultDtypeForTorchScalar` helper above. The latter case follows the
// rules in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
// `NumToTensor` falls in the latter case.
Type type = op.a().getType();
if (type.isa<Torch::FloatType>())
knowledge.dtype = Float64Type::get(op.getContext());
else if (type.isa<Torch::IntType>())
knowledge.dtype = IntegerType::get(op.getContext(), 64, IntegerType::Signed);
return getLatticeElement(op.getResult()).join(knowledge);
}