mirror of https://github.com/llvm/torch-mlir
[TorchDynamo] Enable Elementwise ops for Scalar arg (#2744)
This commit provides a first, deliberately simple solution to support elementwise operations (mul, add) with a scalar argument, i.e. op(Tensor, Scalar). It replaces `torch.aten.add.Tensor` with `torch.aten.add.Scalar` (and likewise for mul) when the second argument is a scalar.

```
Unexpected outcome summary: (torchdynamo)

****** Unexpectedly Passed tests - 22 tests
    XPASS - "AddCDivModule_basic"
    XPASS - "BatchNorm1DModule_basic"
    XPASS - "BatchNorm1DStaticShapeModule_basic"
    XPASS - "BatchNorm1DWith2DInputModule_basic"
    XPASS - "BatchNorm2DModule_basic"
    XPASS - "BatchNorm3DModule_basic"
    XPASS - "ElementwiseAddScalarInt64Module_basic"
    XPASS - "ElementwiseAddScalarIntModule_basic"
    XPASS - "ElementwiseMulScalarModule_basic"
    XPASS - "ElementwiseMulScalarModule_float"
    XPASS - "ElementwiseMulScalarModule_int"
    XPASS - "GroupNormModule_basic"
    XPASS - "GroupNormNoWeightAndBiasModule_basic"
    XPASS - "MobilenetV3Module_basic"
    XPASS - "NativeBatchNorm1DModule_basic"
    XPASS - "NativeBatchNorm2DModule_basic"
    XPASS - "NativeBatchNorm3DModule_basic"
    XPASS - "NativeBatchNormNoneWeightModule_basic"
    XPASS - "NativeGroupNormBackwardModule_basic"
    XPASS - "NativeGroupNormModule_basic"
    XPASS - "ResNet18Module_basic"
    XPASS - "ResNet18StaticModule_basic"
```

The test "ElementwiseAddScalar_TensorLiteralInt32_Module_basic" now segfaults: this change does not yet handle tensors that are not forward arguments but local variables of the model, e.g. `self.x = torch.tensor(..)`.

See also: #2745

Signed-off-by: Dmitrii Makarenko <dmitrii.makarenko@intel.com>

pull/2920/head
parent 8fb28661f9
commit 4b1e87ce67
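To illustrate the situation this patch targets, here is a minimal sketch, independent of torch-mlir, of the kind of FX graph the Dynamo backend receives when an elementwise op has a Python scalar operand. `make_fx` is used here only as a stand-in for the graphs produced by the TorchDynamo/AOT Autograd pipeline; the function `f` and its input are made up for the example.

```python
# Minimal sketch (assumes a recent PyTorch; not part of this patch): tracing an
# elementwise op whose second operand is a Python scalar still yields the
# .Tensor overload, with the scalar inlined as a literal argument.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x * 2.0 + 1.0

gm = make_fx(f)(torch.randn(3))
print(gm.graph)
# Expect call_function nodes roughly like:
#   %mul = call_function[target=torch.ops.aten.mul.Tensor](args = (%x_1, 2.0), ...)
#   %add = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, 1.0), ...)
# The torch dialect importer rejects these ("'torch.aten.mul.Tensor' op operand #1
# must be Any Torch tensor type, but got '!torch.float'"), which is what this
# patch works around by switching such nodes to the .Scalar overloads.
```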
@@ -215,10 +215,6 @@ TORCHDYNAMO_XFAIL_SET = {
     'ConstantBoolParameterModule_basic',
 
     # START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
-    "AddCDivModule_basic",
-    "ElementwiseMulScalarModule_basic",
-    "ElementwiseMulScalarModule_float",
-    "NativeGroupNormBackwardModule_basic",
     "UpSampleNearest2dDynamicSize_basic",
     "UpSampleNearest2dStaticFactor_basic",
     "UpSampleNearest2dStaticSize_basic",
@@ -226,23 +222,7 @@ TORCHDYNAMO_XFAIL_SET = {
     # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
 
     # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
-    "AtenInstanceNormModule_basic",
-    "BatchNorm1DModule_basic",
-    "BatchNorm1DWith2DInputModule_basic",
-    "BatchNorm2DModule_basic",
-    "BatchNorm3DModule_basic",
-    "BatchNorm1DStaticShapeModule_basic",
     "ElementwiseAddScalarFloatModule_basic",
-    "ElementwiseAddScalarInt64Module_basic",
-    "ElementwiseAddScalarIntModule_basic",
-    "MobilenetV3Module_basic",
-    "NativeBatchNorm1DModule_basic",
-    "NativeBatchNorm2DModule_basic",
-    "NativeBatchNorm3DModule_basic",
-    "NativeBatchNormNoneWeightModule_basic",
-    "NativeGroupNormModule_basic",
-    "ResNet18Module_basic",
-    "ResNet18StaticModule_basic",
     # END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
 
     # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
@@ -255,9 +235,6 @@ TORCHDYNAMO_XFAIL_SET = {
     # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
     "ElementwiseAtenDivIntScalarModule_basic",
 
-    # ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
-    "ElementwiseMulScalarModule_int",
-
     # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
     "ElementwiseSubScalarFloatModule_basic",
     "ElementwiseSubScalarIntModule_basic",
@@ -315,10 +292,6 @@ TORCHDYNAMO_XFAIL_SET = {
     # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
     "ArangeStartOutViewModule_basic",
 
-    # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
-    "GroupNormModule_basic",
-    "GroupNormNoWeightAndBiasModule_basic",
-
     # Dynamo does not support tracing quantized tensors
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
@@ -376,6 +349,9 @@ TORCHDYNAMO_CRASHING_SET = {
     "MaxPool3dModule_basic",
     "MaxPool3dStaticCeilModeTrueModule_basic",
     "MaxPool3dStaticModule_basic",
+
+    # Looks like incorrect fx graph conversion
+    "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
 }
 
 STABLEHLO_PASS_SET = {
@@ -53,6 +53,40 @@ def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
             return False
     return True
 
+# Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor with
+# torch.aten.add.Scalar/torch.aten.mul.Scalar when the second argument is a Scalar.
+# This cannot be done at an earlier stage, e.g. in _FXGraphImporter, as it
+# needs to check the argument types, which are not yet determined there.
+# Ideally the schema or target would be chosen up front, but that is decided in
+# _dynamo eval_frame on the PyTorch side, and the Python schema does not match
+# the MLIR schema - see include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td.
+# In general this covers some of the overload cases that are handled automatically
+# on the Python side, e.g. the Scalar -> Tensor conversion and vice versa.
+def scalarize_tensor_ops_on_scalars(gm: torch.fx.GraphModule):
+    # Modify gm.graph
+    for node in gm.graph.nodes:
+        # Check whether we're calling a function (i.e.
+        # torch.add).
+        if node.op == 'call_function':
+            # The target attribute is the function
+            # that call_function calls, e.g.:
+            # call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {})
+            if node.target == torch.ops.aten.add.Tensor:
+                if len(node.args) != 2 or node.kwargs != {}:
+                    continue
+                elif not isinstance(node.args[1], torch.fx.node.Node):
+                    node.target = torch.ops.aten.add.Scalar
+            if node.target == torch.ops.aten.mul.Tensor:
+                if len(node.args) != 2 or node.kwargs != {}:
+                    continue
+                elif not isinstance(node.args[1], torch.fx.node.Node):
+                    node.target = torch.ops.aten.mul.Scalar
+
+    gm.graph.lint()  # Does some checks to make sure the Graph is well-formed.
+
+    # Recompile the forward() method of `gm` from its Graph
+    gm.recompile()
+
 
 def jit(
     model: torch.nn.Module,
@@ -87,6 +121,8 @@ def jit(
         # way of differentiating between the two.
         assert not _returns_empty_tuple(gm), "encountered graph that does not return anything"
 
+        scalarize_tensor_ops_on_scalars(gm)
+
         nonlocal mlir_module
         *_, model_name, nth_graph = get_aot_compilation_context()
         mlir_module = import_fx_graph_as_func(gm.graph, model_name)
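For reference, a self-contained sketch of how the rewrite behaves end to end. The helper below restates the same target swap inline so the example runs without torch-mlir installed; the real pass is `scalarize_tensor_ops_on_scalars` above, called from `jit` before the FX graph is imported. The names `_SCALAR_OVERLOADS` and `scalarize` are local to this example, not part of the patch.

```python
# Self-contained sketch mirroring scalarize_tensor_ops_on_scalars on a make_fx
# graph (assumes a recent PyTorch; make_fx stands in for the Dynamo pipeline).
import torch
from torch.fx.experimental.proxy_tensor import make_fx

_SCALAR_OVERLOADS = {
    torch.ops.aten.add.Tensor: torch.ops.aten.add.Scalar,
    torch.ops.aten.mul.Tensor: torch.ops.aten.mul.Scalar,
}

def scalarize(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target not in _SCALAR_OVERLOADS:
            continue
        # Only the plain op(Tensor, Scalar) form is rewritten; nodes with kwargs
        # or extra positional args are left alone, as in the committed pass.
        if len(node.args) == 2 and not node.kwargs and not isinstance(
                node.args[1], torch.fx.Node):
            node.target = _SCALAR_OVERLOADS[node.target]
    gm.graph.lint()
    gm.recompile()
    return gm

gm = make_fx(lambda x: x * 2.0 + 1.0)(torch.randn(3))
scalarize(gm)
print(gm.graph)  # now targets aten.mul.Scalar / aten.add.Scalar
# Numerics are unchanged: 2.0 * 2.0 + 1.0 == 5.0
torch.testing.assert_close(gm(torch.full((3,), 2.0)), torch.full((3,), 5.0))
```

Note that the swap only inspects argument types after tracing, at backend time, which is why the patch places it in `jit` rather than in `_FXGraphImporter`.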