mirror of https://github.com/llvm/torch-mlir
Fix bug in NumToTensor handling of float values
This commit fixes a type promotion bug when NumToTensor was given a float as an argument. In particular, the rules for type promotion of a scalar vary depending on if the scalar is part of a tensor op or not. NumToTensor falls under the second category, but it was being treated as part of the first category.pull/438/head snapshot-20211123.102
parent
1dc374014b
commit
56c6e3676b
|
@ -630,7 +630,7 @@ def LogSoftmaxIntModule_basic(module, tu: TestUtils):
|
|||
module.forward(torch.randn(3, 2, 4).double())
|
||||
|
||||
|
||||
class NumToTensorModule(torch.nn.Module):
|
||||
class NumToTensorIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -642,8 +642,26 @@ class NumToTensorModule(torch.nn.Module):
|
|||
def forward(self):
|
||||
return torch.ops.prim.NumToTensor(1)
|
||||
|
||||
@register_test_case(module_factory=lambda: NumToTensorModule())
|
||||
def NumToTensorModule_basic(module, tu: TestUtils):
|
||||
@register_test_case(module_factory=lambda: NumToTensorIntModule())
|
||||
def NumToTensorIntModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class NumToTensorFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
|
||||
def forward(self):
|
||||
return torch.ops.prim.NumToTensor(1.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: NumToTensorFloatModule())
|
||||
def NumToTensorFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
|
|
|
@ -1290,7 +1290,19 @@ ChangeResult TypeAnalyzer::visitNumToTensorOp(PrimNumToTensorScalarOp op) {
|
|||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = true;
|
||||
knowledge.dtype = getDefaultDtypeForTorchScalar(op.a().getType());
|
||||
|
||||
// The resulting type from converting a Scalar into a Tensor is different
|
||||
// if the scalar is part of a tensor operation (such as AtenMulScalar) or
|
||||
// not. In the former case, the type promotion rules are captured by the
|
||||
// `getDefaultDtypeForTorchScalar` helper above. The latter case follows the
|
||||
// rules in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
|
||||
// `NumToTensor` falls in the latter case.
|
||||
Type type = op.a().getType();
|
||||
if (type.isa<Torch::FloatType>())
|
||||
knowledge.dtype = Float64Type::get(op.getContext());
|
||||
else if (type.isa<Torch::IntType>())
|
||||
knowledge.dtype = IntegerType::get(op.getContext(), 64, IntegerType::Signed);
|
||||
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue