Fix error in RefineTypes for constant alloc ops (#579)

This commit fixes an error in the refine types pass of constant
allocation ops. The function used to set the dtype,
`fillInDtypeGivenDtypeAndDataType`, takes two torch types as arguments,
but a torch type and a standard MLIR type were being passed into it.

This commit also fixes the way the dtype was calculated in
`visitAtenToDtypeOp`. This op was also passing a standard MLIR type as
an argument to the `fillInDtypeGivenDtypeAndDataType`
function. Moreover, since the op `aten.to.dtype` has the dtype
argument as not optional, all that is needed is to match
against the int value to extract the dtype.
pull/587/head
Ramiro Leal-Cavazos 2022-02-10 18:02:18 -08:00 committed by GitHub
parent 1ab2e3260b
commit c1167853db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 105 additions and 2 deletions

View File

@ -11,6 +11,22 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class ZerosModuleDefaultDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.zeros(3, 4)
@register_test_case(module_factory=lambda: ZerosModuleDefaultDtype())
def ZerosModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward()
class ZerosModuleInt2D(torch.nn.Module):
def __init__(self):
super().__init__()
@ -92,6 +108,22 @@ def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
# ==============================================================================
class OnesModuleDefaultDtype(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.ones(3, 4)
@register_test_case(module_factory=lambda: OnesModuleDefaultDtype())
def OnesModuleDefaultDtype_basic(module, tu: TestUtils):
module.forward()
class OnesModuleInt(torch.nn.Module):
def __init__(self):
super().__init__()
@ -141,6 +173,22 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
# ==============================================================================
class EmptyDefaultDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
return torch.pow(torch.empty((3, 4)), 0)
@register_test_case(module_factory=lambda: EmptyDefaultDtypeModule())
def EmptyModule_defaultDtype(module, tu: TestUtils):
module.forward()
class EmptyIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -191,6 +239,23 @@ def EmptyModule_falsePinMemory(module, tu: TestUtils):
# ==============================================================================
class EmptyLikeDefaultDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.pow(torch.empty_like(a), 0.0)
@register_test_case(module_factory=lambda: EmptyLikeDefaultDtypeModule())
def EmptyLikeModule_defaultDtype(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
class EmptyLikeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -244,6 +309,23 @@ def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
# ==============================================================================
class ZerosLikeDefaultDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.zeros_like(a)
@register_test_case(module_factory=lambda: ZerosLikeDefaultDtypeModule())
def ZerosLikeModule_defaultDtype(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
class ZerosLikeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -296,6 +378,23 @@ def ZerosLikeModule_falsePinMemory(module, tu: TestUtils):
# ==============================================================================
class OnesLikeDefaultDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ones_like(a)
@register_test_case(module_factory=lambda: OnesLikeDefaultDtypeModule())
def OnesLikeModule_defaultDtype(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
class OnesLikeIntModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -752,6 +752,8 @@ static void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
// type.
static void fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge,
Value dtype, Type dataType) {
assert(isa<TorchDialect>(dataType.getDialect()) &&
"`dataType` must be a torch type");
Type dtypeForDataType = getDefaultDtypeForTorchScalar(dataType);
fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, dtypeForDataType);
}
@ -1538,7 +1540,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocLikeOp(
knowledge.hasSizes = true;
knowledge.sizes = input.sizes;
}
fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), input.dtype);
fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
return getLatticeElement(op.getResult()).join(knowledge);
}
@ -1552,7 +1554,9 @@ ChangeResult TypeAnalyzer::visitAtenToDtypeOp(
if (input.hasSizes)
knowledge.sizes = input.sizes;
Value dtype = op.dtype();
fillInDTypeGivenDTypeAndDataType(knowledge, dtype, input.dtype);
int64_t dtypeInt;
if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt);
return getLatticeElement(op.getResult()).join(knowledge);
}