mirror of https://github.com/llvm/torch-mlir
Fix bug in NumToTensor handling of float values
This commit fixes a type promotion bug when NumToTensor was given a float as an argument. In particular, the rules for type promotion of a scalar vary depending on if the scalar is part of a tensor op or not. NumToTensor falls under the second category, but it was being treated as part of the first category.pull/438/head snapshot-20211123.102
parent
1dc374014b
commit
56c6e3676b
|
@ -630,7 +630,7 @@ def LogSoftmaxIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randn(3, 2, 4).double())
|
module.forward(torch.randn(3, 2, 4).double())
|
||||||
|
|
||||||
|
|
||||||
class NumToTensorModule(torch.nn.Module):
|
class NumToTensorIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -642,8 +642,26 @@ class NumToTensorModule(torch.nn.Module):
|
||||||
def forward(self):
|
def forward(self):
|
||||||
return torch.ops.prim.NumToTensor(1)
|
return torch.ops.prim.NumToTensor(1)
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: NumToTensorModule())
|
@register_test_case(module_factory=lambda: NumToTensorIntModule())
|
||||||
def NumToTensorModule_basic(module, tu: TestUtils):
|
def NumToTensorIntModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
class NumToTensorFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return torch.ops.prim.NumToTensor(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NumToTensorFloatModule())
|
||||||
|
def NumToTensorFloatModule_basic(module, tu: TestUtils):
|
||||||
module.forward()
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1290,7 +1290,19 @@ ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||||
knowledge.hasSizes = true;
|
knowledge.hasSizes = true;
|
||||||
knowledge.dtype = getDefaultDtypeForTorchScalar(op.a().getType());
|
|
||||||
|
// The resulting type from converting a Scalar into a Tensor is different
|
||||||
|
// if the scalar is part of a tensor operation (such as AtenMulScalar) or
|
||||||
|
// not. In the former case, the type promotion rules are captured by the
|
||||||
|
// `getDefaultDtypeForTorchScalar` helper above. The latter case follows the
|
||||||
|
// rules in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
|
||||||
|
// `NumToTensor` falls in the latter case.
|
||||||
|
Type type = op.a().getType();
|
||||||
|
if (type.isa<Torch::FloatType>())
|
||||||
|
knowledge.dtype = Float64Type::get(op.getContext());
|
||||||
|
else if (type.isa<Torch::IntType>())
|
||||||
|
knowledge.dtype = IntegerType::get(op.getContext(), 64, IntegerType::Signed);
|
||||||
|
|
||||||
return getLatticeElement(op.getResult()).join(knowledge);
|
return getLatticeElement(op.getResult()).join(knowledge);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue