From 2c56ef9252c03f5372922cd710f65908e0d02b4c Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Mon, 8 Apr 2024 20:05:42 +0800 Subject: [PATCH] [Torch Dialect] canonicalize aten.sign to aten.sgn (#3112) * `aten.sign` is a sub-set of `aten.sgn` (`aten.sgn` support complex type). --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 91 ++++++++++--------- lib/Dialect/Torch/IR/TorchOps.cpp | 11 +++ .../Transforms/AbstractInterpLibrary.cpp | 8 ++ .../Torch/Transforms/DecomposeComplexOps.cpp | 48 ++++++---- projects/pt1/e2e_testing/xfail_sets.py | 5 + .../build_tools/abstract_interp_lib_gen.py | 8 ++ .../build_tools/torch_ods_gen.py | 2 +- .../test_suite/elementwise.py | 45 +++++++++ 8 files changed, 152 insertions(+), 66 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 233d3488f..79c1fb379 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -391,51 +391,6 @@ def Torch_AtenSigmoid_Op : Torch_Op<"aten.sigmoid_", [ }]; } -def Torch_AtenSignOp : Torch_Op<"aten.sign", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::sign : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenSignOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenSignOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - -def Torch_AtenSign_Op : Torch_Op<"aten.sign_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::sign_ : (Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self - ); - let results = (outs - AnyTorchOptionalNonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenSign_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenSign_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_AtenSinhOp : Torch_Op<"aten.sinh", [ AllowsTypeRefinement, HasValueSemantics, @@ -4218,6 +4173,52 @@ def Torch_AtenRound_Op : Torch_Op<"aten.round_", [ }]; } +def Torch_AtenSignOp : Torch_Op<"aten.sign", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sign : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSignOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSignOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenSign_Op : Torch_Op<"aten.sign_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::sign_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition 
= [{ + ParseResult AtenSign_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSign_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index d0d63b8a4..5802e9122 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1793,6 +1793,17 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { return {}; } +//===----------------------------------------------------------------------===// +// AtenSignOp +//===----------------------------------------------------------------------===// +void AtenSignOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenSignOp op, PatternRewriter &rewriter) { + rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf()); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenMulScalarOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 444fdca68..9903ec841 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6442,6 +6442,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.sgn\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10129,6 +10133,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sgn\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.floor\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 77976ead9..e720ca04a 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -6971,46 +6971,54 @@ public: } // namespace namespace { -// Decompose `aten.sign` op into comparisons and aten.where. -class DecomposeAtenSignOp : public OpRewritePattern { +// Decompose `aten.sgn` op into comparisons and aten.where. 
+class DecomposeAtenSgnOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenSignOp op, + LogicalResult matchAndRewrite(AtenSgnOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto outType = op.getType().dyn_cast(); - if (!outType) - return rewriter.notifyMatchFailure( - op, "Only tensor types input are currently supported"); + auto outType = op.getType().cast(); + if (!outType.hasDtype()) { + return rewriter.notifyMatchFailure(op, + "expected result type to have dtype"); + } + // TODO: support complex type in future. + if (outType.getDtype().isa()) { + return rewriter.notifyMatchFailure(op, + "doesn't support complex type now"); + } auto zero = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); auto one = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); auto minusOne = - rewriter.create(loc, rewriter.getF64FloatAttr(-1.0)); + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); auto compTy = outType.getWithSizesAndDtype(outType.getOptionalSizes(), rewriter.getI1Type()); auto greater = rewriter.create(loc, compTy, op.getSelf(), zero); - auto greaterEqual = - rewriter.create(loc, compTy, op.getSelf(), zero); + auto less = + rewriter.create(loc, compTy, op.getSelf(), zero); // Pseudo code: - // if (in >= 0) - // if (in > 0) + // if (in > 0) // return 1 - // else - // return 0 - // else + // else if (in < 0) // return -1 + // else + // return 0 + // note: return 0 if nan/0.0/-0.0 + // return 1 if inf + // return -1 if -inf auto selectGreater = rewriter.create(loc, outType, greater, one, zero); - rewriter.replaceOpWithNewOp( - op, outType, greaterEqual, selectGreater, minusOne); + rewriter.replaceOpWithNewOp(op, outType, less, + minusOne, selectGreater); return success(); } }; @@ -7606,7 +7614,7 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index ee4632b06..769940217 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -888,6 +888,8 @@ STABLEHLO_PASS_SET = { "ViewTwoFiveThreeStaticModule_basic", "ViewTwoToThreeStaticModule_basic", "ElementwiseLog1pModule_basic", + "ElementwiseSgnModule_basic", + "ElementwiseSignIntModule_basic", } STABLEHLO_CRASHING_SET = { @@ -897,6 +899,8 @@ STABLEHLO_CRASHING_SET = { # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. 
TOSA_PASS_SET = { + "ElementwiseSgnModule_basic", + "ElementwiseSignIntModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddCDivModule_basic", @@ -1567,6 +1571,7 @@ ONNX_XFAIL_SET = { "ViewSizeFromOtherTensor_basic", # Failure - onnx_export + "ElementwiseSgnModule_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 10be100db..ff8e7cf47 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -200,6 +200,9 @@ def aten〇floor〡shape(self: List[int]) -> List[int]: def aten〇sign〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇sgn〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇detach〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2346,6 +2349,11 @@ def aten〇sign〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇sgn〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇floor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index d8a6d0b45..b79277ee0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -269,7 +269,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): "aten::log : (Tensor) -> (Tensor)", "aten::selu : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", - "aten::sign : (Tensor) -> (Tensor)", "aten::sinh : (Tensor) -> (Tensor)", "aten::sgn : (Tensor) -> (Tensor)", "aten::hardsigmoid : (Tensor) -> (Tensor)", @@ -357,6 +356,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index bfe7979f0..26f4676f3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1986,6 +1986,51 @@ class ElementwiseSignModule(torch.nn.Module): 
@register_test_case(module_factory=lambda: ElementwiseSignModule()) def ElementwiseSignModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[-2.0, 0.0, 1.1, 2.0], + [6.0, -0.0, torch.inf, -torch.inf]])) + + +# ============================================================================== + + +class ElementwiseSignIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int64, True), + ]) + def forward(self, a): + return torch.ops.aten.sign(a) + + +@register_test_case(module_factory=lambda: ElementwiseSignIntModule()) +def ElementwiseSignIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100)) + + +# ============================================================================== + + +class ElementwiseSgnModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.sgn(a) + + +@register_test_case(module_factory=lambda: ElementwiseSgnModule()) +def ElementwiseSgnModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4))
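

Illustrative check (not part of the patch): a minimal eager-PyTorch sketch of the behaviour this
change relies on — `aten.sign` agrees with `aten.sgn` on real-valued tensors, and the new
`DecomposeAtenSgnOp` lowering is two comparisons feeding `aten.where`. The sample values echo the
new e2e tests; NaN is left out because, per the in-patch comment, the decomposition maps NaN to 0.

import torch

# Inputs echoing the new e2e tests: floats with signed zero and +/-inf,
# plus an integer tensor.
float_input = torch.tensor([[-2.0, 0.0, 1.1, 2.0],
                            [6.0, -0.0, torch.inf, -torch.inf]])
int_input = torch.randint(-100, 100, (3, 4))

def decomposed_sgn(x: torch.Tensor) -> torch.Tensor:
    """Python mirror of DecomposeAtenSgnOp for non-complex inputs:
    1 where x > 0, -1 where x < 0, 0 otherwise."""
    one = torch.tensor(1, dtype=x.dtype)
    minus_one = torch.tensor(-1, dtype=x.dtype)
    zero = torch.tensor(0, dtype=x.dtype)
    return torch.where(x < 0, minus_one, torch.where(x > 0, one, zero))

for t in (float_input, int_input):
    # sign is a subset of sgn on real dtypes, so sign -> sgn is value-preserving.
    assert torch.equal(torch.sign(t), torch.sgn(t))
    # The comparison/where decomposition reproduces the same values.
    assert torch.equal(decomposed_sgn(t), torch.sgn(t))

Canonicalizing `aten.sign` into `aten.sgn` (rather than decomposing both) leaves a single
decomposition path; the complex-dtype case of `aten.sgn` is rejected in the pattern and left as a
TODO, matching the commit message.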