mirror of https://github.com/llvm/torch-mlir
Fix error in RefineTypes for constant alloc ops (#579)
This commit fixes an error in the refine types pass of constant allocation ops. The function used to set the dtype, `fillInDtypeGivenDtypeAndDataType`, takes two torch types as arguments, but a torch type and a standard MLIR type were being passed into it. This commit also fixes the way the dtype was calculated in `visitAtenToDtypeOp`. This op was also passing a standard MLIR type as an argument to the `fillInDtypeGivenDtypeAndDataType` function. Moreover, since the op `aten.to.dtype` has the dtype argument as not optional, all that is needed is to match against the int value to extract the dtype.pull/587/head
parent
1ab2e3260b
commit
c1167853db
|
@ -11,6 +11,22 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ZerosModuleDefaultDtype(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.zeros(3, 4)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosModuleDefaultDtype())
|
||||||
|
def ZerosModuleDefaultDtype_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
class ZerosModuleInt2D(torch.nn.Module):
|
class ZerosModuleInt2D(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -92,6 +108,22 @@ def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class OnesModuleDefaultDtype(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.ones(3, 4)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: OnesModuleDefaultDtype())
|
||||||
|
def OnesModuleDefaultDtype_basic(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
class OnesModuleInt(torch.nn.Module):
|
class OnesModuleInt(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -141,6 +173,22 @@ def OnesModuleFalsePinMemory_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class EmptyDefaultDtypeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
])
|
||||||
|
def forward(self):
|
||||||
|
return torch.pow(torch.empty((3, 4)), 0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EmptyDefaultDtypeModule())
|
||||||
|
def EmptyModule_defaultDtype(module, tu: TestUtils):
|
||||||
|
module.forward()
|
||||||
|
|
||||||
|
|
||||||
class EmptyIntModule(torch.nn.Module):
|
class EmptyIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -191,6 +239,23 @@ def EmptyModule_falsePinMemory(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class EmptyLikeDefaultDtypeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.pow(torch.empty_like(a), 0.0)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EmptyLikeDefaultDtypeModule())
|
||||||
|
def EmptyLikeModule_defaultDtype(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 5))
|
||||||
|
|
||||||
|
|
||||||
class EmptyLikeIntModule(torch.nn.Module):
|
class EmptyLikeIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -244,6 +309,23 @@ def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ZerosLikeDefaultDtypeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.zeros_like(a)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ZerosLikeDefaultDtypeModule())
|
||||||
|
def ZerosLikeModule_defaultDtype(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 5))
|
||||||
|
|
||||||
|
|
||||||
class ZerosLikeIntModule(torch.nn.Module):
|
class ZerosLikeIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -296,6 +378,23 @@ def ZerosLikeModule_falsePinMemory(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class OnesLikeDefaultDtypeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.ones_like(a)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: OnesLikeDefaultDtypeModule())
|
||||||
|
def OnesLikeModule_defaultDtype(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 5))
|
||||||
|
|
||||||
|
|
||||||
class OnesLikeIntModule(torch.nn.Module):
|
class OnesLikeIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -752,6 +752,8 @@ static void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge,
|
||||||
// type.
|
// type.
|
||||||
static void fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge,
|
static void fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge,
|
||||||
Value dtype, Type dataType) {
|
Value dtype, Type dataType) {
|
||||||
|
assert(isa<TorchDialect>(dataType.getDialect()) &&
|
||||||
|
"`dataType` must be a torch type");
|
||||||
Type dtypeForDataType = getDefaultDtypeForTorchScalar(dataType);
|
Type dtypeForDataType = getDefaultDtypeForTorchScalar(dataType);
|
||||||
fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, dtypeForDataType);
|
fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, dtypeForDataType);
|
||||||
}
|
}
|
||||||
|
@ -1538,7 +1540,7 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocLikeOp(
|
||||||
knowledge.hasSizes = true;
|
knowledge.hasSizes = true;
|
||||||
knowledge.sizes = input.sizes;
|
knowledge.sizes = input.sizes;
|
||||||
}
|
}
|
||||||
fillInDTypeGivenDTypeAndDataType(knowledge, op.dtype(), input.dtype);
|
fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.dtype(), input.dtype);
|
||||||
return getLatticeElement(op.getResult()).join(knowledge);
|
return getLatticeElement(op.getResult()).join(knowledge);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1552,7 +1554,9 @@ ChangeResult TypeAnalyzer::visitAtenToDtypeOp(
|
||||||
if (input.hasSizes)
|
if (input.hasSizes)
|
||||||
knowledge.sizes = input.sizes;
|
knowledge.sizes = input.sizes;
|
||||||
Value dtype = op.dtype();
|
Value dtype = op.dtype();
|
||||||
fillInDTypeGivenDTypeAndDataType(knowledge, dtype, input.dtype);
|
int64_t dtypeInt;
|
||||||
|
if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||||
|
knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt);
|
||||||
return getLatticeElement(op.getResult()).join(knowledge);
|
return getLatticeElement(op.getResult()).join(knowledge);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue