[TorchDynamo] Enable Elemtwise ops for Scalar arg (#2744)

This commit provides dummy solution to support elmentwise operations
(mul, add) with scalar argument. ( op(Tensor, Scalar) )

It replaces `torch.aten.add.Tensor` with `torch.aten.add.Scalar`.
```
Unexpected outcome summary: (torchdynamo)

****** Unexpectedly Passed tests - 22 tests
    XPASS - "AddCDivModule_basic"
    XPASS - "BatchNorm1DModule_basic"
    XPASS - "BatchNorm1DStaticShapeModule_basic"
    XPASS - "BatchNorm1DWith2DInputModule_basic"
    XPASS - "BatchNorm2DModule_basic"
    XPASS - "BatchNorm3DModule_basic"
    XPASS - "ElementwiseAddScalarInt64Module_basic"
    XPASS - "ElementwiseAddScalarIntModule_basic"
    XPASS - "ElementwiseMulScalarModule_basic"
    XPASS - "ElementwiseMulScalarModule_float"
    XPASS - "ElementwiseMulScalarModule_int"
    XPASS - "GroupNormModule_basic"
    XPASS - "GroupNormNoWeightAndBiasModule_basic"
    XPASS - "MobilenetV3Module_basic"
    XPASS - "NativeBatchNorm1DModule_basic"
    XPASS - "NativeBatchNorm2DModule_basic"
    XPASS - "NativeBatchNorm3DModule_basic"
    XPASS - "NativeBatchNormNoneWeightModule_basic"
    XPASS - "NativeGroupNormBackwardModule_basic"
    XPASS - "NativeGroupNormModule_basic"
    XPASS - "ResNet18Module_basic"
    XPASS - "ResNet18StaticModule_basic"
```

And segfault for test
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic". Somehow this
change doesn't allow to use Tensors, that are not forward arguments, but
local variables of model.
e.g. `self.x = torch.tensor(..)`

See also: #2745

Signed-off-by: Dmitrii Makarenko <dmitrii.makarenko@intel.com>
pull/2920/head
Devjiu 2024-03-11 20:22:05 +01:00 committed by GitHub
parent 8fb28661f9
commit 4b1e87ce67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 27 deletions

View File

@ -215,10 +215,6 @@ TORCHDYNAMO_XFAIL_SET = {
'ConstantBoolParameterModule_basic', 'ConstantBoolParameterModule_basic',
# START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' # START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"AddCDivModule_basic",
"ElementwiseMulScalarModule_basic",
"ElementwiseMulScalarModule_float",
"NativeGroupNormBackwardModule_basic",
"UpSampleNearest2dDynamicSize_basic", "UpSampleNearest2dDynamicSize_basic",
"UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticFactor_basic",
"UpSampleNearest2dStaticSize_basic", "UpSampleNearest2dStaticSize_basic",
@ -226,23 +222,7 @@ TORCHDYNAMO_XFAIL_SET = {
# END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
# START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"AtenInstanceNormModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",
"ElementwiseAddScalarFloatModule_basic", "ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt64Module_basic",
"ElementwiseAddScalarIntModule_basic",
"MobilenetV3Module_basic",
"NativeBatchNorm1DModule_basic",
"NativeBatchNorm2DModule_basic",
"NativeBatchNorm3DModule_basic",
"NativeBatchNormNoneWeightModule_basic",
"NativeGroupNormModule_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
# END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' # END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
# ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
@ -255,9 +235,6 @@ TORCHDYNAMO_XFAIL_SET = {
# ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
"ElementwiseAtenDivIntScalarModule_basic", "ElementwiseAtenDivIntScalarModule_basic",
# ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
"ElementwiseMulScalarModule_int",
# ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"ElementwiseSubScalarFloatModule_basic", "ElementwiseSubScalarFloatModule_basic",
"ElementwiseSubScalarIntModule_basic", "ElementwiseSubScalarIntModule_basic",
@ -315,10 +292,6 @@ TORCHDYNAMO_XFAIL_SET = {
# ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4])) # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
"ArangeStartOutViewModule_basic", "ArangeStartOutViewModule_basic",
# ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
"GroupNormModule_basic",
"GroupNormNoWeightAndBiasModule_basic",
# Dynamo does not support tracing quantized tensors # Dynamo does not support tracing quantized tensors
"ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic", "ElementwiseDequantizePerTensorModule_basic",
@ -376,6 +349,9 @@ TORCHDYNAMO_CRASHING_SET = {
"MaxPool3dModule_basic", "MaxPool3dModule_basic",
"MaxPool3dStaticCeilModeTrueModule_basic", "MaxPool3dStaticCeilModeTrueModule_basic",
"MaxPool3dStaticModule_basic", "MaxPool3dStaticModule_basic",
# Looks like incorrect fx graph conversion
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
} }
STABLEHLO_PASS_SET = { STABLEHLO_PASS_SET = {

View File

@ -53,6 +53,40 @@ def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
return False return False
return True return True
# Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor to
# torch.aten.add.Scalar/torch.aten.mul.Scalar in case of Scalar argument
# Cannot be done on earlier stage, e.g. in _FXGraphImporter as it
# needs to check argument types, which are not yet determined.
# Maybe schema or target should be changed, but it decided in
# _dynamo eval_frame on pytorch side. Also Python schema not matches
# with mlir Schema - check include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
# So in general it covers some of overload cases, which done on Python side automatically.
# e.g. conversion Scalar -> Tensor and vice versa
def scalarize_tensor_ops_on_scalars(gm: torch.fx.GraphModule):
# Modify gm.graph
for node in gm.graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
# call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {})
if node.target == torch.ops.aten.add.Tensor:
if len(node.args) != 2 or node.kwargs != {}:
continue
elif not isinstance(node.args[1], torch.fx.node.Node):
node.target = torch.ops.aten.add.Scalar
if node.target == torch.ops.aten.mul.Tensor:
if len(node.args) != 2 or node.kwargs != {}:
continue
elif not isinstance(node.args[1], torch.fx.node.Node):
node.target = torch.ops.aten.mul.Scalar
gm.graph.lint() # Does some checks to make sure the
# Recompile the forward() method of `gm` from its Graph
gm.recompile()
def jit( def jit(
model: torch.nn.Module, model: torch.nn.Module,
@ -87,6 +121,8 @@ def jit(
# way of differentiating between the two. # way of differentiating between the two.
assert not _returns_empty_tuple(gm), "encountered graph that does not return anything" assert not _returns_empty_tuple(gm), "encountered graph that does not return anything"
scalarize_tensor_ops_on_scalars(gm)
nonlocal mlir_module nonlocal mlir_module
*_, model_name, nth_graph = get_aot_compilation_context() *_, model_name, nth_graph = get_aot_compilation_context()
mlir_module = import_fx_graph_as_func(gm.graph, model_name) mlir_module = import_fx_graph_as_func(gm.graph, model_name)