From 7673a8ff2851d4ffc3a79584eddb0ddef741e704 Mon Sep 17 00:00:00 2001 From: Bratislav Filipovic Date: Thu, 19 Sep 2024 14:55:21 +0200 Subject: [PATCH] [TorchToLinalg]Lower torch.gcd to linalg and scf Add verify() method to check if tensors are of integer type. Also check if tensors are of same shape, or if the second tensor is a single element tensor. Add e2e tests. Put them into onnx and stablehlo xfailed sets. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 ++++ lib/Conversion/TorchToLinalg/Linear.cpp | 110 ++++++++++++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 34 ++++++ .../Transforms/AbstractInterpLibrary.cpp | 81 ++++++++++--- projects/pt1/e2e_testing/xfail_sets.py | 6 + .../build_tools/abstract_interp_lib_gen.py | 11 ++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 49 ++++++++ projects/pt1/tools/e2e_test.sh | 2 - 9 files changed, 302 insertions(+), 17 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 0b1a8b257..edadc94dd 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -13148,6 +13148,31 @@ def Torch_AtenStftOp : Torch_Op<"aten.stft", [ }]; } +def Torch_AtenGcdOp : Torch_Op<"aten.gcd", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::gcd : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGcdOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenGcdOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasVerifier = 1; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 52765411b..5bedc826f 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -13,6 +13,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -213,6 +215,112 @@ public: }; } // namespace +namespace { +class ConvertAtenGcdOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(torch::Torch::AtenGcdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto self = adaptor.getSelf(); // tensor A + auto other = adaptor.getOther(); // tensor B of the same size + auto loc = op.getLoc(); + + TensorType resultType = + cast(getTypeConverter()->convertType(op.getType())); + + auto gcdPayloadBody = [&](OpBuilder &b, Location loc, + ValueRange payloadArgs) { + auto A = payloadArgs[0]; + A = b.create(loc, A); + auto B = payloadArgs[1]; + B = b.create(loc, B); + auto two = b.create(loc, 2, A.getType()); + auto one = b.create(loc, 1, A.getType()); + auto zero = b.create(loc, 0, A.getType()); + + auto trailingZeroConditionBlock = [&](mlir::OpBuilder &b, + mlir::Location loc, + mlir::ValueRange whileArgs) { + auto current = whileArgs[0]; + auto counter = whileArgs[1]; + auto currentAndOne = b.create(loc, current, one); + auto cmp = b.create( + loc, mlir::arith::CmpIPredicate::sgt, currentAndOne, one); + b.create(loc, cmp, + ValueRange{current, counter}); + }; + auto trailingZerosBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc, + mlir::ValueRange args) { + auto current = args[0]; + auto counter = args[1]; + auto divided = b.create(loc, current, two); + auto newCounter = b.create(loc, counter, one); + b.create( + loc, ValueRange{divided.getResult(), newCounter.getResult()}); + }; + + auto AtrailingZerosOp = b.create( + loc, TypeRange{A.getType(), zero.getType()}, ValueRange{A, zero}, + trailingZeroConditionBlock, trailingZerosBodyBlock); + auto BtrailingZerosOp = b.create( + loc, TypeRange{B.getType(), zero.getType()}, ValueRange{B, zero}, + trailingZeroConditionBlock, trailingZerosBodyBlock); + + Value AtrailingZerosCount = AtrailingZerosOp.getResult(0); + Value BtrailingZerosCount = BtrailingZerosOp.getResult(0); + auto smalerZerosCount = b.create( + loc, AtrailingZerosCount, BtrailingZerosCount); + auto shiftedA = b.create(loc, A, smalerZerosCount); + auto shiftedB = b.create(loc, B, smalerZerosCount); + + auto findGcdConditionBlock = [&](mlir::OpBuilder &b, mlir::Location loc, + mlir::ValueRange args) { + Value min = b.create(loc, args[0], args[1]); + Value max = + b.create(loc, payloadArgs[0], payloadArgs[1]); + + auto cmp = b.create( + loc, mlir::arith::CmpIPredicate::ne, min, zero); + b.create(loc, cmp, ValueRange{min, max}); + }; + auto findGcdBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc, + mlir::ValueRange args) { + Value min = args[0]; + Value max = args[1]; + max = b.create(loc, max, min); + + auto maxTrailingZerosOp = b.create( + loc, TypeRange{B.getType(), zero.getType()}, ValueRange{max, zero}, + trailingZeroConditionBlock, trailingZerosBodyBlock); + Value maxTrailingZerosCount = maxTrailingZerosOp.getResult(0); + max = b.create(loc, max, maxTrailingZerosCount); + b.create(loc, ValueRange{min, max}); + }; + + auto findGcdWhileOp = b.create( + loc, TypeRange{shiftedA.getType(), shiftedB.getType()}, + ValueRange{shiftedA, shiftedB}, findGcdConditionBlock, + findGcdBodyBlock); + + Value gcdResult = findGcdWhileOp.getResult(1); + gcdResult = + b.create(loc, gcdResult, smalerZerosCount); + + b.create(loc, gcdResult); + }; + + other = torch_to_linalg::createElementwiseLinalgGeneric( + rewriter, loc, ValueRange{self, other}, + cast(self.getType()).getElementType(), gcdPayloadBody); + + rewriter.replaceOpWithNewOp(op, resultType, other); + return success(); + } +}; +} // namespace + namespace { class ConvertAtenFlipOp : public OpConversionPattern { public: @@ -1400,4 +1508,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index bed228671..d45e7cff8 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -5524,3 +5524,37 @@ LogicalResult AtenRot90Op::verify() { return success(); } + +LogicalResult AtenGcdOp::verify() { + + auto selfType = cast(getSelf().getType()); + auto otherType = cast(getOther().getType()); + + if (!selfType.hasDtype() || !selfType.hasSizes() || !otherType.hasDtype() || + !otherType.hasSizes()) + return success(); + + auto selfShape = selfType.getSizes(); + auto otherShape = selfType.getSizes(); + int64_t selfRank = selfShape.size(); + int64_t otherRank = otherShape.size(); + auto selfDtype = selfType.getDtype(); + + if (!isa(selfDtype)) + return emitOpError("expected an integer type for input tensor, but got ") + << selfDtype; + + if (otherRank == 1 && otherShape[0] == 1) + return success(); + + if (selfRank != otherRank) + return emitOpError("Tensors must be of same rank or second tensor must be " + "a single element tensor"); + + for (int i = 0; i < selfRank; i++) { + if (selfShape[i] != otherShape[i]) + return emitOpError("Dimensions od tensors font match in dim ") << i; + } + + return success(); +} diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 59cf69393..5c38a9d74 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6639,6 +6639,72 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %8 : !torch.tuple, list>\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.gcd\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Shapes must be the same or 'other' must be a single element tensor.\"\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.eq.int_list %arg0, %arg1 : !torch.list, !torch.list -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.eq.int %5, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %6 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" }\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gcd\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: aten.gcd works only with integer types\"\n" +" %false = torch.constant.bool false\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %4 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" +" %int4 = torch.constant.int 4\n" +" %int3 = torch.constant.int 3\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11238,21 +11304,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %3 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" -" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" -" return %1 : !torch.bool\n" -" }\n" -" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" -" %int4 = torch.constant.int 4\n" -" %int3 = torch.constant.int 3\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" -" %int11 = torch.constant.int 11\n" -" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index bdb4d7f47..1f2d586df 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -923,6 +923,9 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "SplitTensorNegativeDimModule_basic", "SplitWithSizesListUnpackModule_basic", "SplitWithSizes_Module_basic", + "GCDBatchedModule_I32", + "GCDDynamicModule_I32", + "GCDModule_I32", } FX_IMPORTER_STABLEHLO_CRASHING_SET = { @@ -3126,6 +3129,9 @@ ONNX_XFAIL_SET = { "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "UnfoldModule_basic", + "GCDBatchedModule_I32", + "GCDDynamicModule_I32", + "GCDModule_I32", } if torch_version_for_comparison() < version.parse("2.3.0.dev"): diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index bc49757ee..5cc558f0f 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -265,6 +265,17 @@ def aten〇linalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]: shape = upstream_shape_functions.zero_dim_tensor(A) return shape, shape +def aten〇gcd〡shape(self: List[int], other: List[int]) -> List[int]: + assert self == other or (len(other) == 1 and other[0]==0), "Shapes must be the same or 'other' must be a single element tensor." + return self + +def aten〇gcd〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + assert is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype), "aten.gcd works only with integer types" + return self_dtype + + def aten〇detach〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 5f53e17b9..f5a3d6608 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -964,6 +964,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit( "aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)" ) + emit("aten::gcd : (Tensor, Tensor) -> (Tensor)", has_verifier=True) # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 9b4dbe659..dde3c3074 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -6845,3 +6845,52 @@ class TrilIndicesOfssetGreaterThanRowModule(torch.nn.Module): @register_test_case(module_factory=lambda: TrilIndicesOfssetGreaterThanRowModule()) def TrilIndicesOfssetGreaterThanRowModule_basic(module, tu: TestUtils): module.forward() + + +# ============================================================================== + + +class GCDModule(torch.nn.Module): + @export + @annotate_args([None, [(4, 4), torch.int32, True], [(4, 4), torch.int32, True]]) + def forward(self, A, B): + return torch.gcd(A, B) + + +@register_test_case(module_factory=lambda: GCDModule()) +def GCDModule_I32(module, tu: TestUtils): + A = tu.rand(4, 4).to(dtype=torch.int32) + B = tu.rand(4, 4).to(dtype=torch.int32) + module.forward(A, B) + + +class GCDBatchedModule(torch.nn.Module): + @export + @annotate_args( + [None, [(4, 4, 4), torch.int32, True], [(4, 4, 4), torch.int32, True]] + ) + def forward(self, A, B): + return torch.gcd(A, B) + + +@register_test_case(module_factory=lambda: GCDBatchedModule()) +def GCDBatchedModule_I32(module, tu: TestUtils): + A = tu.rand(4, 4, 4).to(dtype=torch.int32) + B = tu.rand(4, 4, 4).to(dtype=torch.int32) + module.forward(A, B) + + +class GCDDynamicModule(torch.nn.Module): + @export + @annotate_args( + [None, [(-1, -1, -1), torch.int32, True], [(-1, -1, -1), torch.int32, True]] + ) + def forward(self, A, B): + return torch.gcd(A, B) + + +@register_test_case(module_factory=lambda: GCDDynamicModule()) +def GCDDynamicModule_I32(module, tu: TestUtils): + A = tu.rand(3, 4, 4).to(dtype=torch.int32) + B = tu.rand(3, 4, 4).to(dtype=torch.int32) + module.forward(A, B) diff --git a/projects/pt1/tools/e2e_test.sh b/projects/pt1/tools/e2e_test.sh index a16929302..73d3361b6 100755 --- a/projects/pt1/tools/e2e_test.sh +++ b/projects/pt1/tools/e2e_test.sh @@ -8,6 +8,4 @@ cd "$src_dir" # Ensure PYTHONPATH is set for export to child processes, even if empty. export PYTHONPATH=${PYTHONPATH-} -source $project_dir/.env - python -m e2e_testing.main "$@"