mirror of https://github.com/llvm/torch-mlir
Fix error in RefineTypes for constant alloc ops (#579)
This commit fixes an error in the refine types pass of constant allocation ops. The function used to set the dtype, `fillInDtypeGivenDtypeAndDataType`, takes two torch types as arguments, but a torch type and a standard MLIR type were being passed into it. This commit also fixes the way the dtype was calculated in `visitAtenToDtypeOp`. This op was also passing a standard MLIR type as an argument to the `fillInDtypeGivenDtypeAndDataType` function. Moreover, since the op `aten.to.dtype` has the dtype argument as not optional, all that is needed is to match against the int value to extract the dtype.pull/587/head
parent
1ab2e3260b
commit
c1167853db
|
@ -11,6 +11,22 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ZerosModuleDefaultDtype(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.zeros(3, 4)
|
||||
|
||||
@register_test_case(module_factory=lambda: ZerosModuleDefaultDtype())
|
||||
def ZerosModuleDefaultDtype_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class ZerosModuleInt2D(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -92,6 +108,22 @@ def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class OnesModuleDefaultDtype(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.ones(3, 4)
|
||||
|
||||
@register_test_case(module_factory=lambda: OnesModuleDefaultDtype())
|
||||
def OnesModuleDefaultDtype_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class OnesModuleInt(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -141,6 +173,22 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class EmptyDefaultDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.pow(torch.empty((3, 4)), 0)
|
||||
|
||||
@register_test_case(module_factory=lambda: EmptyDefaultDtypeModule())
|
||||
def EmptyModule_defaultDtype(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class EmptyIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -191,6 +239,23 @@ def EmptyModule_falsePinMemory(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class EmptyLikeDefaultDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.pow(torch.empty_like(a), 0.0)
|
||||
|
||||
@register_test_case(module_factory=lambda: EmptyLikeDefaultDtypeModule())
|
||||
def EmptyLikeModule_defaultDtype(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
class EmptyLikeIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -244,6 +309,23 @@ def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ZerosLikeDefaultDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.zeros_like(a)
|
||||
|
||||
@register_test_case(module_factory=lambda: ZerosLikeDefaultDtypeModule())
|
||||
def ZerosLikeModule_defaultDtype(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
class ZerosLikeIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -296,6 +378,23 @@ def ZerosLikeModule_falsePinMemory(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class OnesLikeDefaultDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ones_like(a)
|
||||
|
||||
@register_test_case(module_factory=lambda: OnesLikeDefaultDtypeModule())
|
||||
def OnesLikeModule_defaultDtype(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
class OnesLikeIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -752,6 +752,8 @@ static void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
|
|||
// type.
|
||||
static void fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge,
|
||||
Value dtype, Type dataType) {
|
||||
assert(isa<TorchDialect>(dataType.getDialect()) &&
|
||||
"`dataType` must be a torch type");
|
||||
Type dtypeForDataType = getDefaultDtypeForTorchScalar(dataType);
|
||||
fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, dtypeForDataType);
|
||||
}
|
||||
|
@ -1538,7 +1540,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocLikeOp(
|
|||
knowledge.hasSizes = true;
|
||||
knowledge.sizes = input.sizes;
|
||||
}
|
||||
fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), input.dtype);
|
||||
fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
|
@ -1552,7 +1554,9 @@ ChangeResult TypeAnalyzer::visitAtenToDtypeOp(
|
|||
if (input.hasSizes)
|
||||
knowledge.sizes = input.sizes;
|
||||
Value dtype = op.dtype();
|
||||
fillInDTypeGivenDTypeAndDataType(knowledge, dtype, input.dtype);
|
||||
int64_t dtypeInt;
|
||||
if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||
knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt);
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue