mirror of https://github.com/llvm/torch-mlir
[TorchDynamo] Enable Elementwise ops for Scalar arg (#2744)
This commit provides a simple initial solution to support elementwise operations (mul, add) with a Scalar argument, i.e. op(Tensor, Scalar). It replaces `torch.aten.add.Tensor` with `torch.aten.add.Scalar` when the second operand is not a tensor.

```
Unexpected outcome summary: (torchdynamo)

****** Unexpectedly Passed tests - 22 tests
    XPASS - "AddCDivModule_basic"
    XPASS - "BatchNorm1DModule_basic"
    XPASS - "BatchNorm1DStaticShapeModule_basic"
    XPASS - "BatchNorm1DWith2DInputModule_basic"
    XPASS - "BatchNorm2DModule_basic"
    XPASS - "BatchNorm3DModule_basic"
    XPASS - "ElementwiseAddScalarInt64Module_basic"
    XPASS - "ElementwiseAddScalarIntModule_basic"
    XPASS - "ElementwiseMulScalarModule_basic"
    XPASS - "ElementwiseMulScalarModule_float"
    XPASS - "ElementwiseMulScalarModule_int"
    XPASS - "GroupNormModule_basic"
    XPASS - "GroupNormNoWeightAndBiasModule_basic"
    XPASS - "MobilenetV3Module_basic"
    XPASS - "NativeBatchNorm1DModule_basic"
    XPASS - "NativeBatchNorm2DModule_basic"
    XPASS - "NativeBatchNorm3DModule_basic"
    XPASS - "NativeBatchNormNoneWeightModule_basic"
    XPASS - "NativeGroupNormBackwardModule_basic"
    XPASS - "NativeGroupNormModule_basic"
    XPASS - "ResNet18Module_basic"
    XPASS - "ResNet18StaticModule_basic"
```

The test "ElementwiseAddScalar_TensorLiteralInt32_Module_basic" now segfaults: for some reason this change breaks the use of tensors that are not forward arguments but are local variables of the model, e.g. `self.x = torch.tensor(..)`.

See also: #2745

Signed-off-by: Dmitrii Makarenko <dmitrii.makarenko@intel.com>
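To illustrate the pattern this change targets, here is a minimal sketch (not part of the commit; the use of `make_fx` as a stand-in for the TorchDynamo capture, the function name, and the shapes are all illustrative assumptions) showing that tracing scalar arithmetic produces `.Tensor`-overload nodes whose second operand is a plain Python scalar rather than a tensor:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def forward(x):
    # "* 2.0" and "+ 1" are typically traced as
    # call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 2.0), ...)
    # call_function[target=torch.ops.aten.add.Tensor](args = (%mul, 1), ...)
    # even though operand #1 is a plain Python number, which the torch-mlir
    # importer rejects as '!torch.float' / '!torch.int' (not a tensor type).
    return x * 2.0 + 1


gm = make_fx(forward)(torch.ones(3))  # example input shape is arbitrary
print(gm.graph)
```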
parent 8fb28661f9
commit 4b1e87ce67
@@ -215,10 +215,6 @@ TORCHDYNAMO_XFAIL_SET = {
    'ConstantBoolParameterModule_basic',

    # START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "AddCDivModule_basic",
    "ElementwiseMulScalarModule_basic",
    "ElementwiseMulScalarModule_float",
    "NativeGroupNormBackwardModule_basic",
    "UpSampleNearest2dDynamicSize_basic",
    "UpSampleNearest2dStaticFactor_basic",
    "UpSampleNearest2dStaticSize_basic",
@@ -226,23 +222,7 @@ TORCHDYNAMO_XFAIL_SET = {
    # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'

    # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "AtenInstanceNormModule_basic",
    "BatchNorm1DModule_basic",
    "BatchNorm1DWith2DInputModule_basic",
    "BatchNorm2DModule_basic",
    "BatchNorm3DModule_basic",
    "BatchNorm1DStaticShapeModule_basic",
    "ElementwiseAddScalarFloatModule_basic",
    "ElementwiseAddScalarInt64Module_basic",
    "ElementwiseAddScalarIntModule_basic",
    "MobilenetV3Module_basic",
    "NativeBatchNorm1DModule_basic",
    "NativeBatchNorm2DModule_basic",
    "NativeBatchNorm3DModule_basic",
    "NativeBatchNormNoneWeightModule_basic",
    "NativeGroupNormModule_basic",
    "ResNet18Module_basic",
    "ResNet18StaticModule_basic",
    # END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'

    # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
@@ -255,9 +235,6 @@ TORCHDYNAMO_XFAIL_SET = {
    # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
    "ElementwiseAtenDivIntScalarModule_basic",

    # ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
    "ElementwiseMulScalarModule_int",

    # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "ElementwiseSubScalarFloatModule_basic",
    "ElementwiseSubScalarIntModule_basic",
@@ -315,10 +292,6 @@ TORCHDYNAMO_XFAIL_SET = {
    # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
    "ArangeStartOutViewModule_basic",

    # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "GroupNormModule_basic",
    "GroupNormNoWeightAndBiasModule_basic",

    # Dynamo does not support tracing quantized tensors
    "ElementwiseDequantizePerChannelModule_basic",
    "ElementwiseDequantizePerTensorModule_basic",
@@ -376,6 +349,9 @@ TORCHDYNAMO_CRASHING_SET = {
    "MaxPool3dModule_basic",
    "MaxPool3dStaticCeilModeTrueModule_basic",
    "MaxPool3dStaticModule_basic",

    # Looks like incorrect fx graph conversion
    "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
}

STABLEHLO_PASS_SET = {
@@ -53,6 +53,40 @@ def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
            return False
    return True

# Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor with
# torch.aten.add.Scalar/torch.aten.mul.Scalar when the second argument is a Scalar.
# This cannot be done at an earlier stage, e.g. in _FXGraphImporter, as it
# needs to check argument types, which are not yet determined there.
# Maybe the schema or target should be changed instead, but that is decided in
# _dynamo eval_frame on the PyTorch side. Also, the Python schema does not match
# the MLIR schema - check include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td.
# So in general this covers some of the overload cases that are handled automatically
# on the Python side, e.g. the conversion Scalar -> Tensor and vice versa.
def scalarize_tensor_ops_on_scalars(gm: torch.fx.GraphModule):
    # Modify gm.graph
    for node in gm.graph.nodes:
        # Checks if we're calling a function (i.e.
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls, e.g.
            # call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {})
            if node.target == torch.ops.aten.add.Tensor:
                if len(node.args) != 2 or node.kwargs != {}:
                    continue
                elif not isinstance(node.args[1], torch.fx.node.Node):
                    node.target = torch.ops.aten.add.Scalar
            if node.target == torch.ops.aten.mul.Tensor:
                if len(node.args) != 2 or node.kwargs != {}:
                    continue
                elif not isinstance(node.args[1], torch.fx.node.Node):
                    node.target = torch.ops.aten.mul.Scalar

    gm.graph.lint()  # Does some checks to make sure the Graph is well-formed.

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()


def jit(
    model: torch.nn.Module,
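For reference on the schema mismatch the comment above mentions, the two overloads of `aten.add` differ only in the type of the second operand; this short check (not part of the commit, just a way to inspect the registered schemas) prints them:

```python
import torch

# The .Tensor overload requires a Tensor for `other`; the .Scalar overload
# accepts a plain Python number, which is why the pass retargets nodes whose
# second argument is not an fx Node.
print(torch.ops.aten.add.Tensor._schema)
# aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
print(torch.ops.aten.add.Scalar._schema)
# aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
```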
@@ -87,6 +121,8 @@ def jit(
        # way of differentiating between the two.
        assert not _returns_empty_tuple(gm), "encountered graph that does not return anything"

        scalarize_tensor_ops_on_scalars(gm)

        nonlocal mlir_module
        *_, model_name, nth_graph = get_aot_compilation_context()
        mlir_module = import_fx_graph_as_func(gm.graph, model_name)
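As a rough usage sketch of the new pass (an assumption, not part of the commit: it uses `make_fx` as a stand-in for the TorchDynamo capture that feeds this backend, and assumes `scalarize_tensor_ops_on_scalars` from the diff above is in scope, since its module path is not shown here):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    return x + 1  # traced as torch.ops.aten.add.Tensor(x, 1)


gm = make_fx(f)(torch.ones(3))
# Retargets add.Tensor -> add.Scalar because operand #1 is a Python int,
# then lints and recompiles the GraphModule.
scalarize_tensor_ops_on_scalars(gm)

assert any(
    node.op == "call_function" and node.target == torch.ops.aten.add.Scalar
    for node in gm.graph.nodes
)
print(gm(torch.ones(3)))  # still computes x + 1, now via the Scalar overload
```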