From 5684dc0441a62d89968230cb9b3e9801ecc87898 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Sun, 28 Apr 2024 17:23:40 +0800 Subject: [PATCH 01/30] [Torch] emit aten.celu and decompose it (#3247) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CELU(x)=max(0,x)+min(0,α∗(exp(x/α)−1)) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 47 +++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 8 ++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 45 ++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 3 ++ .../build_tools/abstract_interp_lib_gen.py | 8 ++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 46 ++++++++++++++++++ 8 files changed, 159 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 281637f15..4a234307d 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4810,6 +4810,53 @@ def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [ }]; } +def Torch_AtenCeluOp : Torch_Op<"aten.celu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenCeluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenCelu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenRealOp : Torch_Op<"aten.real", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index faf3d0a25..553a8dc74 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6998,6 +6998,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.celu\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10480,6 +10484,10 @@ StringRef 
mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.celu\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 6cb02297d..677ccc4f2 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2415,6 +2415,50 @@ public: } // namespace +// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1)) +namespace { +class DecomposeAtenCeluOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenCeluOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getSelf(); + Value alpha = op.getAlpha(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + Value constantZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value constantOne = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + + // positiveOutput = max(0,x) + Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); + Value positiveOutput = + rewriter.create(loc, resType, zeroTensor, input); + + // negativeOutput = min(0,alpha∗(exp(x/alpha)−1)) + Value scaledInput = + rewriter.create(loc, resType, input, alpha); + Value expX = rewriter.create(loc, resType, scaledInput); + Value expXM1 = rewriter.create(loc, resType, expX, + constantOne, constantOne); + Value scaledExpXM1 = + rewriter.create(loc, resType, expXM1, alpha); + Value negativeOutput = + rewriter.create(loc, resType, zeroTensor, scaledExpXM1); + Value celuOutput = rewriter.create( + loc, resType, positiveOutput, negativeOutput, constantOne); + + rewriter.replaceOp(op, celuOutput); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenLerpScalarOp : public OpRewritePattern { public: @@ -7705,6 +7749,7 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index c5855a1fa..e7bed6463 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -474,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 87344fb99..276cc47c1 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -951,6 +951,7 @@ STABLEHLO_PASS_SET = { "ElementwiseBitwiseRightShiftInt64Module_basic", "ElementwiseBitwiseRightShiftInt8Module_basic", 
"ElementwiseCeilModule_basic", + "ElementwiseCeluStaticModule_basic", "ElementwiseClampMaxModule_basic", "ElementwiseClampMinModule_basic", "ElementwiseClampMinTensorFloatModule_basic", @@ -1571,6 +1572,8 @@ TOSA_PASS_SET = { "ElementwiseBitwiseXorModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", "ElementwiseCeilModule_basic", + "ElementwiseCeluModule_basic", + "ElementwiseCeluStaticModule_basic", "ElementwiseClampMaxModule_basic", "ElementwiseClampMinModule_basic", "ElementwiseClampModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 6574c0bdc..da486fe46 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -526,6 +526,9 @@ def aten〇elu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇selu〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2652,6 +2655,11 @@ def aten〇prelu〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tu assert self_dtype == weight_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, alpha=1.)) +def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1.) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 3ebd00753..6e449c277 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -472,6 +472,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)") emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::prelu : (Tensor, Tensor) -> (Tensor)") + emit_with_mutating_variants("aten::celu : (Tensor, Scalar) -> (Tensor)") emit("aten::real : (Tensor) -> (Tensor)") emit("aten::imag : (Tensor) -> (Tensor)") emit("aten::view_as_complex : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 47f4a6403..d034e6d1f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1016,6 +1016,52 @@ def ElementwisePreluStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseCeluModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.celu(x, 0.5) + + 
+@register_test_case(module_factory=lambda: ElementwiseCeluModule()) +def ElementwiseCeluModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + +class ElementwiseCeluStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 3], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.celu(x) + + +@register_test_case(module_factory=lambda: ElementwiseCeluStaticModule()) +def ElementwiseCeluStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + class ElementwiseGeluModule(torch.nn.Module): def __init__(self): super().__init__() From 9f64748f97fa543a2b6b227cd26f570622cd26f1 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Mon, 29 Apr 2024 10:09:00 +0800 Subject: [PATCH 02/30] [FxImporter] Synchronize the collection of symbolic torch ops (#3236) --- python/torch_mlir/extras/fx_importer.py | 16 ++++------------ python/torch_mlir/fx.py | 4 ++-- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index c1eec37aa..9acf4ad03 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -236,12 +236,6 @@ _IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0" # set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP if _IS_TORCH_2_1_OR_EARLIER: - SYMBOLIC_TORCH_OPS = { - torch.ops.aten.sym_size, - torch.ops.aten.sym_stride, - torch.ops.aten.sym_numel, - } - SYMBOLIC_OP_TO_TORCH_OP = { (torch.ops.aten.sym_size, 1): torch.ops.aten.size.default, (torch.ops.aten.sym_size, 2): torch.ops.aten.size.int, @@ -249,13 +243,9 @@ if _IS_TORCH_2_1_OR_EARLIER: (torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int, (torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default, } -else: - SYMBOLIC_TORCH_OPS = { - torch.ops.aten.sym_size.int, - torch.ops.aten.sym_stride.int, - torch.ops.aten.sym_numel.default, - } + SYMBOLIC_TORCH_OPS = {key[0] for key in SYMBOLIC_OP_TO_TORCH_OP} +else: SYMBOLIC_OP_TO_TORCH_OP = { torch.ops.aten.sym_size.default: torch.ops.aten.size.default, torch.ops.aten.sym_size.int: torch.ops.aten.size.int, @@ -264,6 +254,8 @@ else: torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default, } + SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP} + @dataclass(frozen=True) class SparsityMeta: diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 0879dbe31..651ccae67 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. 
-from typing import Optional, Union, Dict, Tuple, Any +from typing import Optional, Union, Dict, Tuple, Any, Callable import warnings @@ -25,7 +25,7 @@ def export_and_import( dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, experimental_support_mutation: bool = False, hooks: Optional[FxImporterHooks] = None, - decomposition_table: Optional[list] = None, + decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None, func_name: str = "main", enable_graph_printing: bool = False, **kwargs, From aed2cf3351ab2ffc8e9ccf1cc7e1f4a498071b13 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Mon, 29 Apr 2024 10:51:17 +0800 Subject: [PATCH 03/30] [Torch] emit aten.__contains__.str_list and add folder (#3249) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 ++++++++++++ .../torch-mlir/Dialect/Torch/IR/TorchOps.h | 31 ++++++++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 24 +++++++++++ .../build_tools/torch_ods_gen.py | 1 + test/Dialect/Torch/canonicalize.mlir | 40 +++++++++++++++---- 5 files changed, 113 insertions(+), 8 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 4a234307d..8ebd7b162 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -13626,6 +13626,31 @@ def Torch_AtenWarnOp : Torch_Op<"aten.warn", [ }]; } +def Torch_Aten__Contains__StrListOp : Torch_Op<"aten.__contains__.str_list", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__contains__.str_list : (str[], str) -> (bool)`"; + let arguments = (ins + AnyTorchListOfTorchStringType:$l, + Torch_StringType:$item + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__Contains__StrListOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten__Contains__StrListOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index 4508518bf..f49fef072 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -239,6 +239,37 @@ m_TorchListOfConstantBools(SmallVectorImpl &bind_values) { return detail::torch_list_of_constant_bools_op_binder(bind_values); } +namespace detail { +/// Matches the constant strs stored in a `torch.ListConstruct`. +struct torch_list_of_constant_strs_op_binder { + SmallVectorImpl &bind_values; + + /// Creates a matcher instance that binds the value to bvs if match succeeds. + torch_list_of_constant_strs_op_binder(SmallVectorImpl &bvs) + : bind_values(bvs) {} + + bool match(Operation *op) { + auto listConstruct = dyn_cast(op); + if (!listConstruct) + return false; + for (Value value : listConstruct.getElements()) { + std::string str; + if (matchPattern(value, m_TorchConstantStr(str))) + bind_values.push_back(str); + else + return false; + } + return true; + } +}; +} // namespace detail + +/// Matches the constant strs stored in a `torch.prim.ListConstruct`. 
+inline detail::torch_list_of_constant_strs_op_binder +m_TorchListOfConstantStrs(SmallVectorImpl &bind_values) { + return detail::torch_list_of_constant_strs_op_binder(bind_values); +} + namespace detail { /// Matches the expected tensor and dim from `torch.aten.size.int`. struct torch_tensor_size_int_op_binder { diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 33079e35f..376e7dd2e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2385,6 +2385,30 @@ OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// Aten__Contains__StrListOp +//===----------------------------------------------------------------------===// + +OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) { + StringAttr item = dyn_cast(adaptor.getItem()); + if (!item) + return nullptr; + + if (auto listConstruct = getL().getDefiningOp()) { + if (isListPotentiallyMutated(listConstruct)) + return nullptr; + } + llvm::SmallVector strs; + if (matchPattern(getL(), m_TorchListOfConstantStrs(strs))) { + for (const auto &str : strs) { + if (item.getValue().str() == str) + return getI1IntegerAttr(getContext(), true); + } + return getI1IntegerAttr(getContext(), false); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenLtIntOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 6e449c277..ca867723c 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -974,6 +974,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::format : (...) -> (str)") emit("aten::join : (str, str[]) -> (str)") emit("aten::warn : (str, int) -> ()") + emit("aten::__contains__.str_list : (str[], str) -> (bool)", has_folder=True) # Type conversion ops. 
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 7fd4e9832..a317e4011 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -504,8 +504,8 @@ func.func @torch.aten.eq.str$different_value() -> !torch.bool { // CHECK-LABEL: func.func @torch.aten.eq.str$same_operand( // CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool { -// CHECK-NEXT: %[[F:.*]] = torch.constant.bool true -// CHECK-NEXT: return %[[F]] : !torch.bool +// CHECK-NEXT: %[[TRUE:.*]] = torch.constant.bool true +// CHECK-NEXT: return %[[TRUE]] : !torch.bool func.func @torch.aten.eq.str$same_operand(%arg0: !torch.str) -> !torch.bool { %0 = torch.aten.eq.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool return %0 : !torch.bool @@ -522,8 +522,8 @@ func.func @torch.aten.eq.str$same_value() -> !torch.bool { } // CHECK-LABEL: func.func @torch.aten.ne.str$different_value() -> !torch.bool { -// CHECK: %[[FALSE:.*]] = torch.constant.bool true -// CHECK: return %[[FALSE]] : !torch.bool +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool func.func @torch.aten.ne.str$different_value() -> !torch.bool { %str4 = torch.constant.str "4" %str5 = torch.constant.str "5" @@ -533,16 +533,16 @@ func.func @torch.aten.ne.str$different_value() -> !torch.bool { // CHECK-LABEL: func.func @torch.aten.ne.str$same_operand( // CHECK-SAME: %{{.*}}: !torch.str) -> !torch.bool { -// CHECK-NEXT: %[[F:.*]] = torch.constant.bool false -// CHECK-NEXT: return %[[F]] : !torch.bool +// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-NEXT: return %[[FALSE]] : !torch.bool func.func @torch.aten.ne.str$same_operand(%arg0: !torch.str) -> !torch.bool { %0 = torch.aten.ne.str %arg0, %arg0 : !torch.str, !torch.str -> !torch.bool return %0 : !torch.bool } // CHECK-LABEL: func.func @torch.aten.ne.str$same_value() -> !torch.bool { -// CHECK: %[[TRUE:.*]] = torch.constant.bool false -// CHECK: return %[[TRUE]] : !torch.bool +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool func.func @torch.aten.ne.str$same_value() -> !torch.bool { %str4 = torch.constant.str "4" %str4_0 = torch.constant.str "4" @@ -568,6 +568,30 @@ func.func @torch.aten.len.str$empty() -> !torch.int { return %2 : !torch.int } +// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$false() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func.func @torch.aten.__contains__.str_list$false() -> !torch.bool { + %str = torch.constant.str "c" + %str_0 = torch.constant.str "b" + %str_1 = torch.constant.str "a" + %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list + %2 = torch.aten.__contains__.str_list %1, %str : !torch.list, !torch.str -> !torch.bool + return %2 : !torch.bool +} + +// CHECK-LABEL: func.func @torch.aten.__contains__.str_list$true() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func.func @torch.aten.__contains__.str_list$true() -> !torch.bool { + %str = torch.constant.str "aa" + %str_0 = torch.constant.str "aa" + %str_1 = torch.constant.str "ccc" + %1 = torch.prim.ListConstruct %str_1, %str_0 : (!torch.str, !torch.str) -> !torch.list + %2 = torch.aten.__contains__.str_list %1, %str : !torch.list, !torch.str -> !torch.bool + return %2 : !torch.bool +} + // CHECK-LABEL: func.func 
@torch.aten.__not__ // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: return %[[TRUE]] : !torch.bool From b2185195e8fecb3568d53a97a502fc77a22a6daf Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Mon, 29 Apr 2024 11:06:01 +0800 Subject: [PATCH 04/30] [NFC] Update black version (#3256) * Update black version to support 3.11/3.12 * Reformat code --- .pre-commit-config.yaml | 2 +- build_tools/scrape_releases.py | 1 + .../torchscript_stablehlo_backend_tinybert.py | 1 + .../python/torch_mlir/_dynamo_fx_importer.py | 6 ++-- .../build_tools/torch_ods_gen.py | 6 ++-- .../configs/onnx_backend.py | 4 +-- .../linalg_on_tensors_backends/refbackend.py | 10 +++--- .../test_suite/elementwise.py | 4 +-- .../test_suite/slice_like.py | 1 + .../jit_ir/ivalue_import/debug-module-name.py | 1 + .../object-identity-torch-bug.py | 1 + .../jit_ir/ivalue_import/quantization.py | 1 + .../importer/jit_ir/node_import/debug-info.py | 1 + .../importer/jit_ir/node_import/elif.py | 1 + .../jit_ir/node_import/function-derefine.py | 1 + .../python/importer/jit_ir/node_import/if.py | 1 + .../importer/jit_ir/node_import/loop.py | 1 + .../importer/jit_ir/node_import/prim.py | 1 + .../importer/jit_ir/node_import/tuple.py | 1 + .../importer/jit_ir/node_import/types-bool.py | 1 + .../importer/jit_ir/node_import/types-none.py | 1 + .../importer/jit_ir/node_import/utils.py | 1 + python/torch_mlir/extras/fx_importer.py | 3 +- python/torch_mlir/extras/onnx_importer.py | 31 ++++++++++--------- 24 files changed, 49 insertions(+), 33 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72329026f..f2938e28e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: check-yaml - id: check-added-large-files - repo: https://github.com/psf/black - rev: 22.10.0 + rev: 24.4.2 hooks: - id: black diff --git a/build_tools/scrape_releases.py b/build_tools/scrape_releases.py index 88f19d92b..77aa41c15 100644 --- a/build_tools/scrape_releases.py +++ b/build_tools/scrape_releases.py @@ -2,6 +2,7 @@ See https://github.com/llvm/torch-mlir/issues/1374 """ + import argparse import json diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py index af2af2de3..840ec519d 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py @@ -3,6 +3,7 @@ from torch_mlir import torchscript from transformers import BertForMaskedLM + # Wrap the bert model to avoid multiple returns problem class BertTinyWrapper(torch.nn.Module): def __init__(self) -> None: diff --git a/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py b/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py index fcea14dc1..81908d801 100644 --- a/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py +++ b/projects/pt1/python/torch_mlir/_dynamo_fx_importer.py @@ -257,9 +257,9 @@ class _FXGraphImporter: # FakeTensor's in case of a tuple return with multiple elements. 
self._env: Dict[Tuple[torch.fx.Node, int], ir.Value] = {} self._module = ir.Module.create(ir.Location.unknown()) - self._module.operation.attributes[ - "torch.debug_module_name" - ] = ir.StringAttr.get(func_name) + self._module.operation.attributes["torch.debug_module_name"] = ( + ir.StringAttr.get(func_name) + ) function_type = _extract_function_type_from_graph(g) func = func_dialect.FuncOp( func_name, diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index ca867723c..eea8d31a9 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -285,9 +285,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): (ns, unqual + "_", overload if not is_functional_op else "") ), emitter_td, - traits=["IsTrailingUnderscoreInplaceVariant"] - if not is_functional_op - else [], + traits=( + ["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else [] + ), ) # ========================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py index 6fa845ab3..7f630074e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -46,7 +46,7 @@ def convert_onnx(model, inputs): examples = [] input_names = [] dynamic_tensors = {} - for (index, arg) in enumerate(inputs): + for index, arg in enumerate(inputs): shape = map(lambda d: d if d >= 0 else 1, arg.shape) shape = tuple(shape) examples.append(torch.zeros(size=shape, dtype=arg.dtype)) @@ -55,7 +55,7 @@ def convert_onnx(model, inputs): input_names.append(input_name) dynamic_dims = {} - for (dimindex, dim) in enumerate(arg.shape): + for dimindex, dim in enumerate(arg.shape): if dim < 0: dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex) diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index a1611a1e5..1e958a4d9 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -101,10 +101,12 @@ class RefBackendInvoker: def consume_return_funcs(*args): self.result = tuple( [ - arg - if type in elemental_type_to_ctype - else unranked_memref_to_numpy( - arg, memref_type_to_np_dtype[type] + ( + arg + if type in elemental_type_to_ctype + else unranked_memref_to_numpy( + arg, memref_type_to_np_dtype[type] + ) ) for arg, type in zip(args, ret_types) ] diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index d034e6d1f..8e2875842 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -803,9 +803,7 @@ class QuantizedReluInt32(torch.nn.Module): @register_test_case(module_factory=lambda: QuantizedReluInt32()) def QuantizedReluInt32_basic(module, tu: TestUtils): - module.forward( - tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32) - ) + module.forward(tu.randint(7, 4, low=(-(2**31)), high=(2**31 - 1)).to(torch.int32)) # 
============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py index 07f064de7..be2a80d84 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -342,6 +342,7 @@ def SelectIntNegativeDimAndIndexStaticModule_basic(module, tu: TestUtils): # ============================================================================== + # For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1). # For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index). class SliceScatterModule(torch.nn.Module): diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py index bd21c4e8b..5af1a6b89 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/debug-module-name.py @@ -11,6 +11,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK: module attributes {torch.debug_module_name = "TestModule"} class TestModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py index 4c323ec01..4c325308b 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/object-identity-torch-bug.py @@ -18,6 +18,7 @@ mb = ModuleBuilder() # `torch.Tensor` is just a pointer to a TensorImpl under the hood, and so # naively duplicating a Tensor retains the identity of the TensorImpl. + # CHECK-LABEL: torch.class_type @__torch__.TestModule { class TestModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py index e33985fac..df6f1736c 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/quantization.py @@ -12,6 +12,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK-LABEL: torch.class_type @__torch__.TestModule { class TestModule(torch.nn.Module): def __init__(self): diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py b/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py index 1bc258a42..7e8df49a0 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/debug-info.py @@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK-LABEL: func.func @__torch__.add3 # Note that line-level debug information for parts unannotated in the Torch # graph are ascribed to the first op that carries source information. 
Presently diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/elif.py b/projects/pt1/test/python/importer/jit_ir/node_import/elif.py index 5ee16e391..f3ee0a557 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/elif.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/elif.py @@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK-LABEL: @__torch__.f @mb.import_function @torch.jit.script diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py b/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py index 2acde08ca..f9505b91f 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/function-derefine.py @@ -11,6 +11,7 @@ import typing mb = ModuleBuilder() + # CHECK-LABEL: func.func @__torch__.optional_return( # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.optional { # CHECK: %[[RET:.*]] = torch.derefine %[[ARG]] : !torch.int to !torch.optional diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/if.py b/projects/pt1/test/python/importer/jit_ir/node_import/if.py index 86390f707..02cb8d9f0 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/if.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/if.py @@ -13,6 +13,7 @@ mb = ModuleBuilder() # else branch and making all defined values optional, so no special handling # is needed. + # CHECK-LABEL: @__torch__.prim_If( # CHECK-SAME: %[[B:.*]]: !torch.bool, # CHECK-SAME: %[[I:.*]]: !torch.int) -> !torch.int { diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/loop.py b/projects/pt1/test/python/importer/jit_ir/node_import/loop.py index d432cd6ee..b28d63bb0 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/loop.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/loop.py @@ -11,6 +11,7 @@ import typing mb = ModuleBuilder() + # CHECK-LABEL: func.func @__torch__.prim_Loop_forlike( # CHECK-SAME: %[[MAX_ITERATIONS:.*]]: !torch.int) -> !torch.float { # CHECK: %[[BOOL_TRUE:.*]] = torch.constant.bool true diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/prim.py b/projects/pt1/test/python/importer/jit_ir/node_import/prim.py index 66959257e..759292b6d 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/prim.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/prim.py @@ -15,6 +15,7 @@ import typing mb = ModuleBuilder() + # CHECK-LABEL: func.func @__torch__.prim_NumToTensor( # CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.tensor { # CHECK: %[[RET:.*]] = torch.prim.NumToTensor.Scalar %[[ARG]] : !torch.int -> !torch.tensor diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py b/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py index a1f06c390..b6a313cd4 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/tuple.py @@ -13,6 +13,7 @@ from utils import create_script_function mb = ModuleBuilder() NT = NamedTuple("NT", [("f1", Optional[torch.Tensor]), ("f2", Optional[torch.Tensor])]) + # CHECK-LABEL: func.func @__torch__.tuple( # CHECK-SAME: %[[T0:.*]]: !torch.tensor, # CHECK-SAME: %[[T1:.*]]: !torch.tensor) -> diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py b/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py index 0a27692fc..7cd4c3c16 100644 --- 
a/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/types-bool.py @@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK: @__torch__.returns_bool @mb.import_function @torch.jit.script diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py b/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py index 16a3359da..b0358467c 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/types-none.py @@ -9,6 +9,7 @@ from torch_mlir.jit_ir_importer import ModuleBuilder mb = ModuleBuilder() + # CHECK: @__torch__.returns_none @mb.import_function @torch.jit.script diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/utils.py b/projects/pt1/test/python/importer/jit_ir/node_import/utils.py index 613ccb6a8..b06c38fdf 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/utils.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/utils.py @@ -9,6 +9,7 @@ from torch._C import CompilationUnit # RUN: %PYTHON %s + # Import TorchScript IR string as ScriptFunction. def create_script_function(func_name, ts_ir_str, **kwargs): cu = CompilationUnit() diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 9acf4ad03..24bda3f5b 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -1849,8 +1849,7 @@ def _emit_operation( # Opaque value to indicate something is empty. Used in cases where 'None' # may have a different meaning. -class EmptyType: - ... +class EmptyType: ... Empty = EmptyType() diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index f1064f976..8d0e4cf5a 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -156,8 +156,7 @@ class GraphInfo: return "" -class OnnxImportError(Exception): - ... +class OnnxImportError(Exception): ... 
class NodeImporter: @@ -235,22 +234,22 @@ class NodeImporter: else: default_opset_version = opset_import.version if default_opset_version: - container_op.attributes[ - "torch.onnx_meta.opset_version" - ] = IntegerAttr.get(i64_type, default_opset_version) + container_op.attributes["torch.onnx_meta.opset_version"] = ( + IntegerAttr.get(i64_type, default_opset_version) + ) if opset_versions: - container_op.attributes[ - "torch.onnx_meta.opset_versions" - ] = DictAttr.get(opset_versions) + container_op.attributes["torch.onnx_meta.opset_versions"] = ( + DictAttr.get(opset_versions) + ) container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get( IntegerType.get_signed(64), m.ir_version ) container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get( m.producer_name ) - container_op.attributes[ - "torch.onnx_meta.producer_version" - ] = StringAttr.get(m.producer_version) + container_op.attributes["torch.onnx_meta.producer_version"] = ( + StringAttr.get(m.producer_version) + ) def import_all(self, func=True): """Imports all nodes topologically.""" @@ -658,9 +657,11 @@ ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = { RankedTensorType.get(shape, IntegerType.get_signed(64)), IntegerAttr.get( IntegerType.get_signed(64), - int.from_bytes(tp.raw_data, "little", signed=True) - if tp.HasField("raw_data") - else tp.int64_data[0], + ( + int.from_bytes(tp.raw_data, "little", signed=True) + if tp.HasField("raw_data") + else tp.int64_data[0] + ), ), ), # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB @@ -703,7 +704,7 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = { ), onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get( np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False - ) + ), # Intentionally unsupported: STRING } From b1e22414794db1b25d938b6f8f7dca6376a50990 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 29 Apr 2024 09:30:01 +0530 Subject: [PATCH 05/30] [ONNX] Fix Onnx.Selu lowering and canonicalizer for IntImplicit op (#3221) Signed-Off By: Vivek Khandelwal --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 35 ++++++++++++++++--- lib/Dialect/Torch/IR/TorchOps.cpp | 19 +++++++--- projects/pt1/e2e_testing/xfail_sets.py | 3 -- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 16 ++++++--- 4 files changed, 56 insertions(+), 17 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 586b8d4ff..edb36aee9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -847,15 +847,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( patterns.onOp( "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // y = gamma * (alpha * e^x - alpha) for x <= 0, y = gamma * x for x > 0 Torch::ValueTensorType resultType; float alpha, gamma; Value operand; + // Refer https://onnx.ai/onnx/operators/onnx__Selu.html for the default + // alpha and gamma values. 
if (binder.tensorOperand(operand) || - binder.f32FloatAttr(alpha, "alpha") || - binder.f32FloatAttr(gamma, "gamma") || + binder.f32FloatAttr(alpha, "alpha", 1.67326) || + binder.f32FloatAttr(gamma, "gamma", 1.0507) || binder.tensorResultType(resultType)) return failure(); + Torch::ValueTensorType inputType = + operand.getType().cast(); + Value vAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), alpha)); @@ -864,12 +870,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), gamma)); - Value vInputScale = rewriter.create( + Value cstOne = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), 1.0)); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, vAlpha, vScale, vInputScale); + Value cstNone = rewriter.create(binder.getLoc()); + Value zeroTensor = rewriter.create( + binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone, + cstNone, cstNone); + Value exp = rewriter.create(binder.getLoc(), + resultType, operand); + Value expMulAlpha = rewriter.create( + binder.getLoc(), resultType, exp, vAlpha); + Value expMulAlphaSubAlpha = rewriter.create( + binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne); + Value neg = rewriter.create( + binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale); + Value pos = rewriter.create( + binder.getLoc(), resultType, operand, vScale); + Type compareType = inputType.getWithSizesAndDtype( + inputType.getOptionalSizes(), rewriter.getI1Type()); + Value xLessThanZero = rewriter.create( + binder.getLoc(), compareType, operand, zeroTensor); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, xLessThanZero, neg, pos); return success(); }); patterns.onOp("ReduceL1", 1, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 376e7dd2e..29911961d 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -140,7 +140,7 @@ static Value getScalarIntValue(Value input, Location loc, return nullptr; Type inputDtype = inputTensorType.getOptionalDtype(); - if (!inputDtype || !inputDtype.isInteger(64)) + if (!inputDtype || !(inputDtype.isInteger(64) || inputDtype.isInteger(1))) return nullptr; std::optional inputRank = getTensorRank(input); @@ -148,10 +148,19 @@ static Value getScalarIntValue(Value input, Location loc, return nullptr; if (auto valueTensorLiteralOp = input.getDefiningOp()) { - auto val = cast(valueTensorLiteralOp.getValue()) - .getSplatValue(); - return rewriter.create( - loc, rewriter.getI64IntegerAttr(val)); + if (inputDtype.isInteger(64)) { + auto val = valueTensorLiteralOp.getValue() + .cast() + .getSplatValue(); + return rewriter.create( + loc, rewriter.getI64IntegerAttr(val)); + } else { + auto val = valueTensorLiteralOp.getValue() + .cast() + .getSplatValue(); + return rewriter.create( + loc, rewriter.getI64IntegerAttr(val)); + } } else if (auto primNumToTensorScalarOp = input.getDefiningOp()) { return primNumToTensorScalarOp.getA(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 276cc47c1..e45839617 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2124,7 +2124,6 @@ ONNX_XFAIL_SET = { "ElementwiseAtenFloorDivideTensorNegativeModule_basic", "ElementwiseLog10IntModule_basic", "ElementwiseLog2IntModule_basic", - "ElementwiseSeluModule_basic", "FlipModuleStaticShape_basic", 
"FlipNegativeIndexModule_basic", "HardsigmoidModule_basic", @@ -2637,8 +2636,6 @@ ONNX_XFAIL_SET = { "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "CumsumInputDtypeInt32Module_basic", - "DropoutTrainModule_basic", - "DropoutTrainStaticShapeModule_basic", "ElementwiseAcosIntModule_basic", "ElementwiseAsinIntModule_basic", "ElementwiseAtanTensorIntModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 9c0ab3512..5fe9c79d3 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -582,10 +582,18 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to // CHECK-LABEL: func.func @test_selu func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} { - // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1 - // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2 - // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3 - // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]] + // CHECK: %[[F2:.+]] = torch.constant.float 2.000000e+00 + // CHECK: %[[F3:.+]] = torch.constant.float 3.000000e+00 + // CHECK: %[[F1:.+]] = torch.constant.float 1.000000e+00 + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[ZEROS:.+]] = torch.aten.zeros_like %arg0, %none, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[EXP:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[EXP]], %[[F2]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[MUL]], %[[F2]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[MUL_1:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[MUL_2:.+]] = torch.aten.mul.Scalar %arg0, %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[LT:.+]] = torch.aten.lt.Tensor %arg0, %[[ZEROS]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> + // CHECK: torch.aten.where.self %[[LT]], %[[MUL_1]], %[[MUL_2]] : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } From 0a5ff68d9d57c9c3948b6d60c1edb32da9fe3670 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Mon, 29 Apr 2024 17:40:30 +0800 Subject: [PATCH 06/30] [stablehlo] Support PrimsCollapseOp and PrimsSplitDimOp in stablehlo (#3230) --- .../TorchToStablehlo/StablehloLegalizeUtils.h | 11 ++ .../StablehloLegalizeUtils.cpp | 131 ++++++++++++++++++ lib/Conversion/TorchToStablehlo/ViewLike.cpp | 63 +++++---- projects/pt1/e2e_testing/xfail_sets.py | 8 +- 4 files changed, 182 insertions(+), 31 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h 
b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h index 6e14b324b..734ba81ea 100644 --- a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h @@ -69,6 +69,17 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, Value tensor, ArrayRef inputUnsqzDims, size_t dimSizeIndexBits); +// Get a tensor that collapse the specified dimensions of the input tensor +FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, + Value tensor, int64_t collapseStartDim, + int64_t collapseEndDim, + size_t dimSizeIndexBits); + +// Get a tensor that splits the specified dimensions of the input tensor +FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, + Value tensor, int64_t splitDim, + int64_t outerLength, size_t dimSizeIndexBits); + Value getConstantOfShape(PatternRewriter &rewriter, Location loc, const APFloat &constant, Value shape, TensorType outType); diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index 40ec715cd..c4d629d4f 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -9,6 +9,7 @@ #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" @@ -306,6 +307,136 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, .getResult(); } +FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, + Value tensor, int64_t collapseStartDim, + int64_t collapseEndDim, + size_t dimSizeIndexBits) { + + auto dimSizesInfo = + getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits); + + if (failed(dimSizesInfo)) + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + + auto dimSizes = *dimSizesInfo; + int64_t rank = dimSizes.size(); + + collapseStartDim = toPositiveDim(collapseStartDim, rank); + collapseEndDim = toPositiveDim(collapseEndDim, rank); + + int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1); + + auto loc = op->getLoc(); + auto rankTy = dyn_cast(tensor.getType()); + auto oldShape = rankTy.getShape(); + Type intType = rewriter.getIntegerType(dimSizeIndexBits); + + std::vector newDimSizes; + std::vector newShape; + newDimSizes.reserve(newRank); + newShape.reserve(newRank); + + Value collapseDimSize = rewriter.create( + loc, rewriter.getIntegerAttr(intType, 1)); + int64_t collapseShape = 1; + + for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) { + if (k < 0 || k >= rank) { + return rewriter.notifyMatchFailure( + op, "collapse dimensions must be within the rank of the tensor"); + } + if (collapseShape == ShapedType::kDynamic || + oldShape[k] == ShapedType::kDynamic) { + collapseShape = ShapedType::kDynamic; + } else { + collapseShape *= oldShape[k]; + } + collapseDimSize = + rewriter.create(loc, collapseDimSize, dimSizes[k]); + } + + for (int64_t k = 0; k < collapseStartDim; ++k) { + newDimSizes.push_back(dimSizes[k]); + newShape.push_back(oldShape[k]); + } + newDimSizes.push_back(collapseDimSize); + newShape.push_back(collapseShape); + for (int64_t k = collapseEndDim + 1; k < rank; ++k) { + newDimSizes.push_back(dimSizes[k]); + 
newShape.push_back(oldShape[k]); + } + + auto outTy = RankedTensorType::get(newShape, rankTy.getElementType()); + auto shape = rewriter.create(loc, newDimSizes); + return rewriter.create(loc, outTy, tensor, shape) + .getResult(); +} + +// TODO: support splitDim & outerLength to be Value +FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, + Value tensor, int64_t splitDim, + int64_t outerLength, size_t dimSizeIndexBits) { + auto dimSizesInfo = + getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits); + + if (failed(dimSizesInfo)) + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + + auto dimSizes = *dimSizesInfo; + int64_t rank = dimSizes.size(); + splitDim = toPositiveDim(splitDim, rank); + + auto loc = op->getLoc(); + auto rankTy = dyn_cast(tensor.getType()); + auto oldShape = rankTy.getShape(); + Type intType = rewriter.getIntegerType(dimSizeIndexBits); + + if (splitDim < 0 || splitDim >= rank) { + return rewriter.notifyMatchFailure( + op, "split dimensions must be within the rank of the tensor"); + } + + int64_t newRank = rank + 1; + auto outerLengthValue = rewriter.create( + loc, rewriter.getIntegerAttr(intType, outerLength)); + + auto innerLengthValue = rewriter.create( + loc, dimSizes[splitDim], outerLengthValue); + + int64_t originShape = oldShape[splitDim]; + int64_t outerShape = outerLength; + int64_t innerShape = originShape == ShapedType::kDynamic + ? ShapedType::kDynamic + : originShape / outerLength; + + std::vector newDimSizes; + std::vector newShape; + + newDimSizes.reserve(newRank); + newShape.reserve(newRank); + + for (int64_t k = 0; k < splitDim; ++k) { + newDimSizes.push_back(dimSizes[k]); + newShape.push_back(oldShape[k]); + } + newDimSizes.push_back(outerLengthValue); + newShape.push_back(outerShape); + newDimSizes.push_back(innerLengthValue); + newShape.push_back(innerShape); + + for (int64_t k = splitDim + 1; k < rank; ++k) { + newDimSizes.push_back(dimSizes[k]); + newShape.push_back(oldShape[k]); + } + + auto outTy = RankedTensorType::get(newShape, rankTy.getElementType()); + auto shape = rewriter.create(loc, newDimSizes); + return rewriter.create(loc, outTy, tensor, shape) + .getResult(); +} + Value getConstantOfShape(PatternRewriter &rewriter, Location loc, const APFloat &constant, Value shape, TensorType outType) { diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index e43105ea1..04952d843 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -414,34 +414,44 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only constant end is currently supported"); - start = toPositiveDim(start, rank); - end = toPositiveDim(end, rank); - SmallVector dims; - dims.reserve(rank); - for (int r = 0; r < start; ++r) - dims.push_back(r); - int64_t collapsedDimSize = 1; - for (int r = start; r <= end; ++r) { - if (selfType.getShape()[r] == ShapedType::kDynamic) - return rewriter.notifyMatchFailure( - op, "the size of the dimension being collapsed is can't be unknown"); - collapsedDimSize *= selfType.getShape()[r]; - } - dims.push_back(collapsedDimSize); - for (int r = end + 1; r < rank; ++r) - dims.push_back(r); + auto collapseTensorInfo = hlo::collapseTensor( + rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits); + if (failed(collapseTensorInfo)) + return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor"); - auto newDimSizesInfo = 
hlo::getDimSizesOfTensor( - rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits); - if (failed(newDimSizesInfo)) + rewriter.replaceOp(op, *collapseTensorInfo); + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + PrimsSplitDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto selfType = adaptor.getA().getType().dyn_cast(); + if (!selfType) { + return op.emitError("only tensor types are currently supported"); + } + + auto rank = selfType.getRank(); + if (rank == 0) return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); - auto newDimSizes = *newDimSizesInfo; - auto stablehloShape = - rewriter.create(op.getLoc(), newDimSizes); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), adaptor.getA(), - stablehloShape); + op, "the rank of tensor must be greater than 0"); + + int64_t dim, outerLength; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "only constant dim is currently supported"); + if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength))) + return rewriter.notifyMatchFailure( + op, "only constant outerLength is currently supported"); + + auto splitTensorInfo = hlo::splitTensor( + rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits); + + if (failed(splitTensorInfo)) + return rewriter.notifyMatchFailure(op, "failed to create split tensor"); + + rewriter.replaceOp(op, *splitTensorInfo); return success(); } @@ -458,6 +468,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenSqueezeDimOp); INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); INSERT_ATENOP_PATTERN(PrimsCollapseOp); + INSERT_ATENOP_PATTERN(PrimsSplitDimOp); #undef INSERT_ATENOP_PATTERN #define INSERT_VIEW_OP_PATTERN(AtenOp) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e45839617..10c24b657 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -678,11 +678,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "NumToTensorIntModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", - "PixelShuffleModuleFullDynamic_basic", - "PixelShuffleModuleSpatiallyDynamic_basic", - "PixelShuffleModuleSpatiallyStatic_basic", - "PixelShuffleModuleStaticRank3Int64_basic", - "PixelShuffleModuleStaticRank4Float32_basic", "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", @@ -1157,6 +1152,8 @@ STABLEHLO_PASS_SET = { "Permute0RankModule_basic", "PermuteModule_basic", "PermuteNegativeIndexModule_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", "PowIntFloatModule_basic", "PrimListUnpackNumMismatchModule_basic", "PrimMaxIntModule_basic", @@ -1240,6 +1237,7 @@ STABLEHLO_PASS_SET = { "SliceWholeTensorModule_basic", "SortIntListReverse_basic", "SortIntList_basic", + "SplitDimStaticModule_basic", "SplitTensorGetItem_Module_basic", "SplitTensorLastSmallerModule_basic", "SplitTensorListUnpackModule_basic", From 2176176fefd696d929b9d61b5587a419fae8386d Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Mon, 29 Apr 2024 09:21:12 -0700 Subject: [PATCH 07/30] [FX] Add broadcast test with dynamic dim (#3123) This scenario was uncovered in a downstream test that failed with a previous snapshot of torch-mlir. See https://github.com/cruise-automation/mlir-tcp/actions/runs/8605480116/job/23581829102?pr=65. 
``` File "/home/runner/.cache/bazel/_bazel_runner/ce288f117ee4ca92dc028a6a28476a3d/sandbox/processwrapper-sandbox/2380/execroot/mlir-tcp/bazel-out/k8-opt-exec-2B5CBBC6/bin/test/AotCompile/broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic_torch_exporter.runfiles/pip_deps_torch_mlir/site-packages/torch_mlir/extras/fx_importer.py", line 969, in value_info_to_type raise NotImplementedError( NotImplementedError: Could not deduce type from value info: tensor_meta=None, val=s1, sparsity=None ``` It seems to have resolved on current HEAD. Adding this test to ensure coverage in the future. --- test/python/fx_importer/basic_test.py | 29 ++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 08ef9fdc9..fde318630 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -105,6 +105,33 @@ def test_import_frozen_exported_program_with_dynamic_shapes(): print(m) +@run +# CHECK-LABEL: test_broadcast_with_dynamic_shapes +# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> +def test_broadcast_with_dynamic_shapes(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.broadcast_to(x, (y.shape[0], -1)) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(10) + + dim_0 = Dim("dim_0") + dynamic_shapes = { + "x": {}, + "y": {0: dim_0}, + } + + m = fx.export_and_import( + Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net" + ) + print(m) + + @make_boxed_compiler def fx_import_aot_autograd_backend( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor] @@ -117,7 +144,7 @@ def fx_import_aot_autograd_backend( @run # CHECK-LABEL: test_stateless_fx_import -# CHECK: func.func @basic_forward__6_inference_0(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK: func.func @[[basic:[a-zA-Z0-9_]+]](%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> # CHECK-NEXT: %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> # CHECK-NEXT: return %0 : !torch.vtensor<[3,4],f32> def test_stateless_fx_import(): From 087fea0608dac3995b74e5c22ae7950287fe7a73 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 29 Apr 2024 21:54:04 +0530 Subject: [PATCH 08/30] build: manually update PyTorch version (#3257) Set PyTorch and TorchVision version to nightly release 2024-04-28. 
Signed-Off By: Vivek Khandelwal --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index c2f8c830c..400586976 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -0a3e5f5badd8a0cb7fac97f5ec9d48c304e5c0b7 +34ade3521ca41f20af3469bba276c2b0499c3892 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 256104030..7cd8d44e5 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.4.0.dev20240422 +torch==2.4.0.dev20240428 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index a530cc800..148f66152 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.19.0.dev20240422 +torchvision==0.19.0.dev20240428 From db6721084a2b3f41216e9cc7e0ea9263c33f196e Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 29 Apr 2024 12:01:40 -0700 Subject: [PATCH 09/30] Integrate LLVM at llvm/llvm-project@593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf (#3260) --- externals/llvm-project | 2 +- .../Dialect/TMTensor/IR/TMTensorInterfaces.h | 4 ++-- lib/Dialect/TMTensor/IR/TMTensorOps.cpp | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/externals/llvm-project b/externals/llvm-project index a952c1238..593f6fdcb 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit a952c123880eb1168f1021b116485e27170d48ca +Subproject commit 593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h index 159bcea78..50045438f 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h @@ -30,8 +30,6 @@ namespace detail { LogicalResult verifyTMTensorOpInterface(Operation *op); } -#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export - /// Include the generated interface declarations. #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOpInterfaces.h.inc" // IWYU pragma: export @@ -39,4 +37,6 @@ LogicalResult verifyTMTensorOpInterface(Operation *op); } // namespace torch } // namespace mlir +#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h.inc" // IWYU pragma: export + #endif // TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_ diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index be07ca276..218ecad33 100644 --- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -936,7 +936,7 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern { // If no operand comes from a tensor::CastOp and can be folded then fail. 
bool hasTensorCastOperand = llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { - if (opOperand->get().isa()) + if (isa(opOperand->get())) return false; auto castOp = opOperand->get().getDefiningOp(); return castOp && canFoldIntoConsumerOp(castOp); From 122cf22cc2b1d2006607dc18e8d2309a94172321 Mon Sep 17 00:00:00 2001 From: "Jae Hoon (Antonio) Kim" <17433012+antoniojkim@users.noreply.github.com> Date: Mon, 29 Apr 2024 15:02:12 -0400 Subject: [PATCH 10/30] Re-enable LTC Build (#3261) The LTC Build was disabled in https://github.com/llvm/torch-mlir/pull/3210 due to a regression in the packaging of the torch nightly wheels (https://github.com/pytorch/pytorch/issues/124941) which is now resolved. So, re-enabling LTC build in this PR --- build_tools/ci/build_posix.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/build_tools/ci/build_posix.sh b/build_tools/ci/build_posix.sh index bacb736ba..fec5e252e 100755 --- a/build_tools/ci/build_posix.sh +++ b/build_tools/ci/build_posix.sh @@ -50,6 +50,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \ -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \ -DLLVM_TARGETS_TO_BUILD=host \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DTORCH_MLIR_ENABLE_LTC=ON echo "::endgroup::" echo "::group::Build" From b64c22cfc12c110f9e77857530d014978b2577b8 Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Tue, 30 Apr 2024 00:44:41 -0700 Subject: [PATCH 11/30] Fix onnx sinh lowering (#3253) iree tests `test_sinh` and `test_sinh_example` passed --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 35 +++++++++++++------ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 10 ++++-- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index edb36aee9..197d9c536 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1449,18 +1449,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); - patterns.onOp("Sinh", 9, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); + patterns.onOp( + "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + // 1/2 * (exp(x) – exp(-x)) + Value x = rewriter.create(binder.getLoc(), resultType, + operand); + Value neg = rewriter.create(binder.getLoc(), + resultType, operand); + Value y = + rewriter.create(binder.getLoc(), resultType, neg); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value z = rewriter.create( + binder.getLoc(), resultType, x, y, cstOne); + Value cstTwo = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, z, cstTwo); + return success(); + }); // split with fixed-size parts // Arguments: diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 5fe9c79d3..2748a640a 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ 
b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1265,9 +1265,15 @@ func.func @test_reduce_prod_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>, // ----- -// CHECK-LABEL: func.func @test_sinh +// CHECK-LABEL: func.func @test_sinh_example func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} { - // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[C1:.+]] = torch.constant.int 1 + // CHECK: %[[SUB:.+]] = torch.aten.sub.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[C2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.div.Scalar %[[SUB]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } From aa471f1d9612eb3a3b47a041aaef565944398dd2 Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Tue, 30 Apr 2024 00:49:29 -0700 Subject: [PATCH 12/30] Fix onnx cosh lowering (#3254) iree tests `test_cosh` and `test_cosh_example` passed --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 36 +++++++++++++------ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 16 +++++++-- 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 96f4e55fb..401c83991 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1348,17 +1348,31 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - patterns.onOp("Cosh", 9, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + patterns.onOp( + "Cosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + // 1/2 * (exp(x) + exp(-x)) + Value x = rewriter.create(binder.getLoc(), resultType, + operand); + Value neg = rewriter.create(binder.getLoc(), + resultType, operand); + Value y = + rewriter.create(binder.getLoc(), resultType, neg); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value z = rewriter.create( + binder.getLoc(), resultType, x, y, cstOne); + Value cstTwo = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, z, cstTwo); + return success(); + }); patterns.onOp( "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Location loc = binder.getLoc(); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir 
b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index f53e55a16..0719512f0 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -665,7 +665,13 @@ func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5 // CHECK-LABEL: @test_cosh_example func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[C1:.+]] = torch.constant.int 1 + // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[C2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.div.Scalar %[[ADD]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } @@ -674,7 +680,13 @@ func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[ // CHECK-LABEL: @test_cosh func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[C1:.+]] = torch.constant.int 1 + // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[C2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.div.Scalar %[[ADD]], %[[C2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } From fb499192dfe60476c72838e84b8d5b42dfcd6072 Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Tue, 30 Apr 2024 00:49:44 -0700 Subject: [PATCH 13/30] Fix onnx acosh lowering (#3262) iree tests `test_acosh` and `test_acosh_example` passed --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 34 +++++++++++++------ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 14 ++++++-- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 401c83991..f5b05327c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ 
b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -242,17 +242,29 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - patterns.onOp("Acosh", 9, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + patterns.onOp( + "Acosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + // log(x + sqrt(x**2 - 1)) + Value square = rewriter.create( + binder.getLoc(), resultType, operand); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value sub = rewriter.create( + binder.getLoc(), resultType, square, cstOne, cstOne); + Value sqrt = rewriter.create(binder.getLoc(), + resultType, sub); + Value add = rewriter.create( + binder.getLoc(), resultType, operand, sqrt, cstOne); + rewriter.replaceOpWithNewOp(binder.op, resultType, + add); + return success(); + }); patterns.onOp("BatchNormalization", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 0719512f0..967c35f13 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -695,7 +695,12 @@ func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // CHECK-LABEL: @test_acosh_example func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[C1:.+]] = torch.constant.int 1 + // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[SUB]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.log %[[ADD]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } @@ -704,7 +709,12 @@ func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor< // CHECK-LABEL: @test_acosh func.func @test_acosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : 
!torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[C1:.+]] = torch.constant.int 1 + // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[SUB]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.log %[[ADD]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } From bf04b53b072aa90ea72723b0189c418cbdc4857f Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Tue, 30 Apr 2024 00:49:57 -0700 Subject: [PATCH 14/30] Fix onnx asinh lowering (#3263) iree tests `test_asinh` and `test_asinh_example` passed --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 34 +++++++++++++------ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 14 ++++++-- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index f5b05327c..7b44e8510 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -198,17 +198,29 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - patterns.onOp("Asinh", 9, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + patterns.onOp( + "Asinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + // log(x + sqrt(x**2 + 1)) + Value square = rewriter.create( + binder.getLoc(), resultType, operand); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value add0 = rewriter.create( + binder.getLoc(), resultType, square, cstOne, cstOne); + Value sqrt = rewriter.create(binder.getLoc(), + resultType, add0); + Value add1 = rewriter.create( + binder.getLoc(), resultType, operand, sqrt, cstOne); + rewriter.replaceOpWithNewOp(binder.op, resultType, + add1); + return success(); + }); patterns.onOp("Atan", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 967c35f13..aca59b8ae 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -741,7 +741,12 @@ func.func @test_asin(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // CHECK-LABEL: @test_asinh_example func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", 
torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[C1:.+]] = torch.constant.int 1 + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } @@ -750,7 +755,12 @@ func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor< // CHECK-LABEL: @test_asinh func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[C1:.+]] = torch.constant.int 1 + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } From fbbad2d81e7cad20b2590fbd2087889a207e2eb6 Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Tue, 30 Apr 2024 00:50:08 -0700 Subject: [PATCH 15/30] Fix onnx atanh lowering (#3264) iree tests `test_atanh` and `test_atanh_example` passed --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 38 +++++++++++++------ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 9 ++++- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 7b44e8510..716ea3d6e 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -232,17 +232,33 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - patterns.onOp("Atanh", 9, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + patterns.onOp( + "Atanh", 9, [](OpBinder binder, 
ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + // 1/2 * log((1 + x) / (1 - x)) + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value add = rewriter.create( + binder.getLoc(), resultType, operand, cstOne, cstOne); + Value neg = rewriter.create(binder.getLoc(), + resultType, operand); + Value sub = rewriter.create( + binder.getLoc(), resultType, neg, cstOne, cstOne); + Value div = rewriter.create( + binder.getLoc(), resultType, add, sub); + Value log = + rewriter.create(binder.getLoc(), resultType, div); + Value cstTwo = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, log, cstTwo); + return success(); + }); patterns.onOp("Acos", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index aca59b8ae..eb2cde696 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -201,7 +201,14 @@ func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // CHECK-LABEL: @test_atanh func.func @test_atanh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.atanh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg0, %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[NEG:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[SUB:.*]] = torch.aten.add.Scalar %[[NEG]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[DIV:.*]] = torch.aten.div.Tensor %[[ADD]], %[[SUB]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[LOG:.*]] = torch.aten.log %[[DIV]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: torch.aten.div.Scalar %[[LOG]], %[[C2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Atanh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } From fb8aed09076bc5073808dfe7057268b3b80543d7 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Tue, 30 Apr 2024 00:55:25 -0700 Subject: [PATCH 16/30] [Release Builds] Use `-no-build-isolation` to decouple from `pyproject.toml` (#3266) Fixes https://github.com/llvm/torch-mlir/issues/3258 In addition disabling the LTC builds since they are already covered in CI (build_posix.sh) and I am not aware of a consumer of this flow in the binary releases of torch-mlir (the main dependency there is from source). 
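For readers unfamiliar with pip's PEP 517 build isolation, here is a minimal sketch of the behavior this flag changes; the wheel directory and requirements-file paths below are illustrative, not the exact release scripts:

```
# Default: pip creates an isolated build environment and installs the build
# requirements declared in pyproject.toml into it, so whatever torch nightly
# is already installed in the outer environment is ignored during the build.
python -m pip wheel -v -w /wheelhouse .

# With --no-build-isolation: pip builds in the current environment, so the
# build dependencies are provisioned up front and the pre-installed torch
# nightly is the one the build actually compiles against.
python -m pip install -r build-requirements.txt
TORCH_MLIR_ENABLE_LTC=0 \
  python -m pip wheel -v --no-build-isolation -w /wheelhouse .
```

In other words, the isolated environment resolves its own dependency versions from pyproject.toml, independent of the torch pins the release scripts install, which appears to be the coupling this change avoids.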
--- build_tools/python_deploy/build_linux_packages.sh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 4feccdd64..625020836 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -432,6 +432,8 @@ function clean_build() { } function build_torch_mlir() { + # Disable LTC build for releases + export TORCH_MLIR_ENABLE_LTC=0 local torch_version="$1" case $torch_version in nightly) @@ -440,7 +442,7 @@ function build_torch_mlir() { --extra-index-url https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html CMAKE_GENERATOR=Ninja \ TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ - python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir \ + python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir \ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html \ -r /main_checkout/torch-mlir/whl-requirements.txt ;; @@ -450,7 +452,7 @@ function build_torch_mlir() { python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt CMAKE_GENERATOR=Ninja \ TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ - python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir + python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir ;; *) echo "Unrecognized torch version '$torch_version'" @@ -474,7 +476,7 @@ function build_torch_mlir_core() { TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \ TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \ - python -m pip wheel -v -w /wheelhouse /main_checkout/torch-mlir + python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir } function clean_wheels() { From f32ada993d393581ae1e70ac6b47dbdd4a70dca1 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Wed, 1 May 2024 00:06:13 +0800 Subject: [PATCH 17/30] [Stablehlo] Improve the lowering of pool op in stablehlo (#3259) 1. Handle case stride == None 2. 
add avgpool3d maxpool1d maxpool3d lowering --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 28 ++ lib/Conversion/TorchToStablehlo/Pooling.cpp | 279 +++++++++++------- .../Transforms/AbstractInterpLibrary.cpp | 14 +- projects/pt1/e2e_testing/xfail_sets.py | 3 + .../build_tools/abstract_interp_lib_gen.py | 13 +- .../build_tools/torch_ods_gen.py | 1 + 6 files changed, 216 insertions(+), 122 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 8ebd7b162..cb08ffd53 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6637,6 +6637,34 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [ }]; } +def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_BoolType:$ceil_mode + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenMaxPool1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 132410a2a..9219b4af3 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -36,7 +36,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, auto constType = RankedTensorType::get({}, elementTy); // Avg pooling if (isa(op)) { + AtenAvgPool3dOp, AtenCumsumOp>(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( @@ -54,7 +54,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, } // Max pooling - if (isa(op)) { + if (isa(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, @@ -75,101 +76,6 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, return nullptr; } -// AtenMaxPool2dOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenMaxPool2dOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = cast(input.getType()); - auto inputElemTy = inputTy.getElementType(); - - auto inputRank = inputTy.getRank(); - auto outTy = - cast(getTypeConverter()->convertType(op.getType())); - - if (inputRank <= 2) { - return op.emitError( - "max_pooling2d only supports inputs with rank higher than 2"); - } - SmallVector padding, kernelSize, stride, dilation; - bool ceilMode = false; - - if (!(matchPattern(op.getKernelSize(), - m_TorchListOfConstantInts(kernelSize)))) { - return rewriter.notifyMatchFailure( - op, "non-const int kernel size unsupported!"); - } - if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { - return rewriter.notifyMatchFailure(op, "non-const int 
stride unsupported!"); - } - if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { - return rewriter.notifyMatchFailure(op, - "non-const int padding unsupported!"); - } - if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) { - return rewriter.notifyMatchFailure(op, - "non-const int dilation unsupported!"); - } - if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { - return rewriter.notifyMatchFailure(op, - "non-const bool ceil_mode unsupported!"); - } - - // prepend 1 to kernelSize, stride, dilation until they are of same rank as - // input - SmallVector stablehloStride(inputRank, 1); - SmallVector stablehloDilation(inputRank, 1); - SmallVector stablehloKernelSize(inputRank, 1); - SmallVector stablehloPadding(inputRank * 2, 0); - std::copy(dilation.begin(), dilation.end(), - stablehloDilation.begin() + inputRank - 2); - std::copy(stride.begin(), stride.end(), - stablehloStride.begin() + inputRank - 2); - std::copy(kernelSize.begin(), kernelSize.end(), - stablehloKernelSize.begin() + inputRank - 2); - - Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - - stablehloPadding[stablehloPadding.size() - 4] = padding[0]; - stablehloPadding[stablehloPadding.size() - 3] = padding[0]; - stablehloPadding[stablehloPadding.size() - 2] = padding[1]; - stablehloPadding[stablehloPadding.size() - 1] = padding[1]; - - auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); - auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); - DenseI64ArrayAttr baseDilations; - auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); - DenseIntElementsAttr pad = DenseIntElementsAttr::get( - RankedTensorType::get( - {static_cast(inputRank), static_cast(2)}, - rewriter.getI64Type()), - stablehloPadding); - auto reduceWindowOp = rewriter.create( - op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, - baseDilations, windowDilations, pad); - - Block &block = reduceWindowOp.getBody().emplaceBlock(); - - auto blockArgumentTy = RankedTensorType::get({}, inputElemTy); - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArg = block.args_begin(); - auto secondArg = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value result = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), result); - } - - rewriter.replaceOp(op, reduceWindowOp.getResults()); - return success(); -} - // AtenMaxPool2dWithIndicesOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -356,6 +262,129 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +namespace { +template +class ConvertAtenMaxPoolOp : public ConvertAtenOp { +public: + using ConvertAtenOp::ConvertAtenOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getSelf(); + auto inputTy = cast(input.getType()); + auto inputElemTy = inputTy.getElementType(); + auto inputRank = inputTy.getRank(); + auto outTy = cast( + ConvertAtenOp::getTypeConverter()->convertType(op.getType())); + + if (inputRank <= Dim) { + return op.emitError( + "max_pooling1d/2d only supports inputs with rank higher than 1/2"); + } + SmallVector padding, kernelSize, stride, dilation; + bool ceilMode = false; + + if 
(!(matchPattern(op.getKernelSize(), + m_TorchListOfConstantInts(kernelSize)))) { + return rewriter.notifyMatchFailure( + op, "non-const int kernel size unsupported!"); + } + if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { + return rewriter.notifyMatchFailure(op, + "non-const int stride unsupported!"); + } + if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { + return rewriter.notifyMatchFailure(op, + "non-const int padding unsupported!"); + } + if (!(matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilation)))) { + return rewriter.notifyMatchFailure(op, + "non-const int dilation unsupported!"); + } + if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { + return rewriter.notifyMatchFailure( + op, "non-const bool ceil_mode unsupported!"); + } + + if (stride.empty()) { + stride = kernelSize; + } + + // prepend 1 to kernelSize, stride, dilation until they are of same rank + // as input + SmallVector stablehloStride(inputRank, 1); + SmallVector stablehloDilation(inputRank, 1); + SmallVector stablehloKernelSize(inputRank, 1); + SmallVector stablehloPadding(inputRank * 2, 0); + std::copy(dilation.begin(), dilation.end(), + stablehloDilation.begin() + inputRank - Dim); + std::copy(stride.begin(), stride.end(), + stablehloStride.begin() + inputRank - Dim); + std::copy(kernelSize.begin(), kernelSize.end(), + stablehloKernelSize.begin() + inputRank - Dim); + + Value initVal = + createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + + if (Dim == 1) { + stablehloPadding[stablehloPadding.size() - 2] = padding[0]; + stablehloPadding[stablehloPadding.size() - 1] = padding[0]; + } else if (Dim == 2) { + stablehloPadding[stablehloPadding.size() - 4] = padding[0]; + stablehloPadding[stablehloPadding.size() - 3] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[1]; + stablehloPadding[stablehloPadding.size() - 1] = padding[1]; + } else if (Dim == 3) { + stablehloPadding[stablehloPadding.size() - 6] = padding[0]; + stablehloPadding[stablehloPadding.size() - 5] = padding[0]; + stablehloPadding[stablehloPadding.size() - 4] = padding[1]; + stablehloPadding[stablehloPadding.size() - 3] = padding[1]; + stablehloPadding[stablehloPadding.size() - 2] = padding[2]; + stablehloPadding[stablehloPadding.size() - 1] = padding[2]; + } else { + assert(false && "Unsupported pooling dimension"); + } + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + DenseI64ArrayAttr baseDilations; + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); + + DenseIntElementsAttr pad = DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast(inputRank), static_cast(2)}, + rewriter.getI64Type()), + stablehloPadding); + + auto reduceWindowOp = rewriter.create( + op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, + baseDilations, windowDilations, pad); + + Block &block = reduceWindowOp.getBody().emplaceBlock(); + + // Add bb argument + auto blockArgumentType = RankedTensorType::get({}, inputElemTy); + block.addArgument(blockArgumentType, op->getLoc()); + block.addArgument(blockArgumentType, op->getLoc()); + auto *firstArg = block.args_begin(); + auto secondArg = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + + Value result = rewriter.create(op->getLoc(), *firstArg, + *secondArg); + rewriter.create(op->getLoc(), result); + } 
+ + rewriter.replaceOp(op, reduceWindowOp.getResults()); + return success(); + } +}; +} // namespace + namespace { template class ConvertAtenAvgPoolOp : public ConvertAtenOp { @@ -375,8 +404,8 @@ public: auto outShape = outTy.getShape(); if (inputRank <= Dim) { - return op.emitError( - "avg_pooling1d/2d only supports inputs with rank higher than 1/2"); + return op.emitError("avg_pooling1d/2d/3d only supports inputs with rank " + "higher than 1/2/3"); } SmallVector padding, kernelSize, stride; bool ceilMode = false; @@ -405,6 +434,10 @@ public: op, "non-const bool count_include_pad unsupported!"); } + if (stride.empty()) { + stride = kernelSize; + } + if constexpr (std::is_same()) { if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) return rewriter.notifyMatchFailure( @@ -425,11 +458,20 @@ public: if (Dim == 1) { stablehloPadding[stablehloPadding.size() - 2] = padding[0]; stablehloPadding[stablehloPadding.size() - 1] = padding[0]; - } else { + } else if (Dim == 2) { stablehloPadding[stablehloPadding.size() - 4] = padding[0]; stablehloPadding[stablehloPadding.size() - 3] = padding[0]; stablehloPadding[stablehloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1]; + } else if (Dim == 3) { + stablehloPadding[stablehloPadding.size() - 6] = padding[0]; + stablehloPadding[stablehloPadding.size() - 5] = padding[0]; + stablehloPadding[stablehloPadding.size() - 4] = padding[1]; + stablehloPadding[stablehloPadding.size() - 3] = padding[1]; + stablehloPadding[stablehloPadding.size() - 2] = padding[2]; + stablehloPadding[stablehloPadding.size() - 1] = padding[2]; + } else { + assert(false && "Unsupported pooling dimension"); } Value initVal = @@ -474,10 +516,17 @@ public: divisor = hlo::getConstTensor(rewriter, op, {kernelSize[0]}, {}) .value(); - } else { + } else if (Dim == 2) { divisor = hlo::getConstTensor( rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) .value(); + } else if (Dim == 3) { + divisor = hlo::getConstTensor( + rewriter, op, + {kernelSize[0] * kernelSize[1] * kernelSize[2]}, {}) + .value(); + } else { + assert(false && "Unsupported pooling dimension"); } divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); DenseI64ArrayAttr bcastDimensions; @@ -611,22 +660,28 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); - target.addIllegalOp(); - patterns.add>(typeConverter, context, options); - target.addIllegalOp(); - patterns.add>(typeConverter, context, options); - target.addIllegalOp(); - patterns.add>(typeConverter, context, options); - target.addIllegalOp(); - patterns.add>(typeConverter, - context, options); - target.addIllegalOp(); - patterns.add>(typeConverter, context, options); +#define INSERT_ATEN_POOLING_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, options) + INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp); + INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp); +#undef INSERT_ATEN_POOLING_PATTERN + +#define INSERT_ATEN_MAXPOOL_PATTERN(AtenOp, Dim) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, \ + options) + INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool1dOp, 1); + INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool2dOp, 2); + INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool3dOp, 3); +#undef INSERT_ATEN_MAXPOOL_PATTERN + #define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \ 
target.addIllegalOp(); \ patterns.add>(typeConverter, context, \ options) INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1); INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2); + INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool3dOp, 3); #undef INSERT_ATEN_AVGPOOL_PATTERN } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 553a8dc74..d9ac7a6d0 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7845,19 +7845,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %arg2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__.avg_pool1d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool) -> !torch.list\n" +" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @__torch__.avg_pool1d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list {\n" +" func.func @__torch__.pool1d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool) -> !torch.list {\n" " %int-1 = torch.constant.int -1\n" " %int-2 = torch.constant.int -2\n" " %int-3 = torch.constant.int -3\n" " %str = torch.constant.str \"AssertionError: \"\n" -" %str_0 = torch.constant.str \"AssertionError: avg_pool1d: padding must be a single int\"\n" -" %str_1 = torch.constant.str \"AssertionError: avg_pool1d: stride must either be omitted, or a single int\"\n" +" %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n" +" %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" -" %str_2 = torch.constant.str \"AssertionError: avg_pool1d: kernel_size must be a single int\"\n" +" %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n" " %int1 = torch.constant.int 1\n" " %int0 = torch.constant.int 0\n" " %int2 = torch.constant.int 2\n" @@ -7940,6 +7940,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %23 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.max_pool1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 10c24b657..8ffe8d1c3 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1075,6 +1075,9 @@ STABLEHLO_PASS_SET = { "Matmul_vecmat", 
"MatmulStaticBroadcast_basic", "MaxPool2dStaticModule_basic", + "MaxPool2dEmptyStrideStaticModule_basic", + "MaxPool3dStaticModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", "MeanDimAllReduceModule_basic", "MeanDimEmptyDimModule_basic", "MeanDimNoneDimModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index da486fe46..eb6062056 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -961,14 +961,14 @@ def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padd # TODO: This should be upstreamed. # See https://github.com/pytorch/pytorch/pull/76889 for an example. -def avg_pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool): - assert len(kernel_size) == 1, "avg_pool1d: kernel_size must be a single int" +def pool1d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool): + assert len(kernel_size) == 1, "pool1d: kernel_size must be a single int" kL = kernel_size[0] - assert len(stride) == 0 or len(stride) == 1, "avg_pool1d: stride must either be omitted, or a single int" + assert len(stride) == 0 or len(stride) == 1, "pool1d: stride must either be omitted, or a single int" dL = kL if len(stride) == 0 else stride[0] - assert len(padding) == 1, "avg_pool1d: padding must be a single int" + assert len(padding) == 1, "pool1d: padding must be a single int" padL = padding[0] dilationL = 1 @@ -1004,7 +1004,10 @@ def adaptive_avg_pool1d(self: List[int], out: List[int]): return shape def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> List[int]: - return avg_pool1d(self, kernel_size, stride, padding, ceil_mode, count_include_pad) + return pool1d(self, kernel_size, stride, padding, ceil_mode) + +def aten〇max_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]: + return pool1d(self, kernel_size, stride, padding, ceil_mode) def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]: return adaptive_avg_pool1d(self, output_size) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index eea8d31a9..e0329c8df 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -591,6 +591,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit( "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)" ) + emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit( "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" From 05f8b69bf66f7727fc4870e51efa74c2f276b624 Mon Sep 17 00:00:00 2001 From: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:51:27 +0530 
Subject: [PATCH 18/30] [MLIR][TORCH] Add OnnxToTorch support for BlackmanWindow function (#3181) Implements OnnxToTorch lowering for the BlackmanWindow Function. --- .../Conversion/TorchOnnxToTorch/Utils.h | 7 + .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 122 ++++++++++++++++++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 8 -- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 78 +++++++++++ 4 files changed, 207 insertions(+), 8 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h index d4ace352a..919146c6a 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -38,6 +38,13 @@ Value createConstantIntList(OpBinder binder, Type getQTorchTypeFromTorchIntType(Type ty); +template +Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, + Value &ofItem) { + return rewriter.create(binder.getLoc(), + rewriter.getType(), ofItem); +} + LogicalResult OnnxLstmExpander(OpBinder binder, ConversionPatternRewriter &rewriter); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 716ea3d6e..bd5c57fac 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -2240,4 +2240,126 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone); return success(); }); + patterns.onOp( + "BlackmanWindow", 17, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value size; + Torch::ValueTensorType resultType; + int64_t periodic, output_datatype; + if (binder.tensorOperand(size) || + binder.s64IntegerAttr(output_datatype, "output_datatype", 1) || + binder.s64IntegerAttr(periodic, "periodic", 1) || + binder.tensorResultType(resultType)) { + return failure(); + } + double isPeriodicFp = static_cast(periodic); + Value a0 = rewriter.create( + binder.getLoc(), + rewriter.getFloatAttr(rewriter.getF64Type(), 0.42)); + Value a1 = rewriter.create( + binder.getLoc(), + rewriter.getFloatAttr(rewriter.getF64Type(), -0.5)); + Value a2 = rewriter.create( + binder.getLoc(), + rewriter.getFloatAttr(rewriter.getF64Type(), 0.08)); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(0.0)); + Value one = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(1.0)); + Value two = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(2.0)); + + constexpr double pi = llvm::numbers::pi; + Value tau = rewriter.create( + binder.getLoc(), + rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi)); + + Value noneVal = rewriter.create(binder.getLoc()); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value float32Type = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(/*float32Type*/ 6)); + + // Create an f32 ValueTensorType with thse same size as size, the + // operand + auto shapeOfOperand = size.getType() + .dyn_cast() + .getOptionalSizes(); + auto f32ResultType = rewriter.getType( + shapeOfOperand, rewriter.getF32Type()); + Value periodicSizeFloat = rewriter.create( + binder.getLoc(), f32ResultType, size, float32Type, cstFalse, + cstFalse, noneVal); + Value symmetricSizeFloat = rewriter.create( + binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat, + one, one); + + Value isPeriodic = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(isPeriodicFp)); + Value isSymmetricFloat = 
rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(1.0 - isPeriodicFp)); + + Value periodicComponent = rewriter.create( + binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat, + isPeriodic); + Value symmetricComponent = rewriter.create( + binder.getLoc(), symmetricSizeFloat.getType(), symmetricSizeFloat, + isSymmetricFloat); + Value sizeFloat = rewriter.create( + binder.getLoc(), symmetricComponent.getType(), symmetricComponent, + periodicComponent, one); + + // Here, size can be used in the place of periodicSizeFloat, as the + // latter is just a float representation of the former. + Value scalarLimit = getItemOp(binder, rewriter, size); + + Value rangeArr = rewriter.create( + binder.getLoc(), resultType, zero, scalarLimit, one, noneVal, + noneVal, noneVal, noneVal); + + Value rangeTimesTau = rewriter.create( + binder.getLoc(), resultType, rangeArr, tau); + Value rangeAngular = rewriter.create( + binder.getLoc(), resultType, rangeTimesTau, sizeFloat); + Value twoRangeAngular = rewriter.create( + binder.getLoc(), resultType, rangeAngular, two); + + Value cosRangeAngular = rewriter.create( + binder.getLoc(), resultType, rangeAngular); + Value cosTwoRangeAngular = rewriter.create( + binder.getLoc(), resultType, twoRangeAngular); + + Value a1Component = rewriter.create( + binder.getLoc(), resultType, cosRangeAngular, a1); + Value a2Component = rewriter.create( + binder.getLoc(), resultType, cosTwoRangeAngular, a2); + + // AtenSubScalarOp actually requires a tensor operand as the LHS, that + // is, operand #1. Therefore, to avoid errors, the onnx implementation + // has been modified. a1 has been changed to negative half, and the + // AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add + // operation is commutative. + Value subA1Component = rewriter.create( + binder.getLoc(), resultType, a1Component, a0, one); + Value result = rewriter.create( + binder.getLoc(), resultType, subA1Component, a2Component, one); + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(output_datatype); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value outputDtype = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + dtypeIntTorch.value())); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, result, outputDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/noneVal); + return success(); + }); } diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 197d9c536..5f9da3faa 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -31,15 +31,7 @@ using namespace mlir::torch::onnx_c; // thing here, so we simplify. 
// utilities -// Templatized function to get an item op of a type namespace { -template -Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, - Value &ofItem) { - return rewriter.create(binder.getLoc(), - rewriter.getType(), ofItem); -} - // In case the ReduceSum Op was not the first operation performed on the data, // we provide the original operand through storeResult, which will be modified // if the result will be passed onto another operation, and will be used for diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index eb2cde696..a068acbf2 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -2035,3 +2035,81 @@ func.func @test_eyelike_dynamic(%arg0: !torch.vtensor<[3,?],f32>) -> !torch.vten %0 = torch.operator "onnx.EyeLike"(%arg0) {torch.onnx.k = -1 : si64} : (!torch.vtensor<[3,?],f32>) -> !torch.vtensor<[3,?],f32> return %0 : !torch.vtensor<[3,?],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_blackmanwindow_symmetric +func.func @test_blackmanwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01 + // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01 + // CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02 + // CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00 + // CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6 + // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int + // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> 
!torch.vtensor<[10],f32> + // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32> + // CHECK: return %[[CAST]] : !torch.vtensor<[10],f32> + %0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> + return %0 : !torch.vtensor<[10],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_blackmanwindow +func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[A0:.+]] = torch.constant.float 4.200000e-01 + // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01 + // CHECK-DAG: %[[A2:.+]] = torch.constant.float 8.000000e-02 + // CHECK-DAG: %[[FLOAT0_0:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[FLOAT2:.+]] = torch.constant.float 2.000000e+00 + // CHECK-DAG: %[[TWOPI:.+]] = torch.constant.float 6.2831853071795862 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[INT6:.+]] = torch.constant.int 6 + // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[FLOAT1]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, 
!torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[FLOAT1]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int + // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[FLOAT0_0]], %[[RANGELIM]], %[[FLOAT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TWOPI]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[FLOAT2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[COSTWORANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[COSTWORANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[FLOAT1]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32> + // CHECK: return %[[CAST]] : !torch.vtensor<[10],f32> + %0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> + return %0 : !torch.vtensor<[10],f32> +} From 9442c6685698b111c936c2d0c2e173b5e56b88d7 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 30 Apr 2024 09:21:39 -0700 Subject: [PATCH 19/30] [torch-mlir][sparse] add a few missing passes to the ref pipeline (#3265) For some sparse programs (and I am sure other not-seen corner cases for dense), some passes were missing in the reference pipeline, eventually resulting in e.g. a unresolved unrealized cast issue. This PR adds some very obvious missing passes to avoid this situation. 
--- .../linalg_on_tensors_backends/refbackend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 1e958a4d9..08e8ff64d 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -180,6 +180,7 @@ LOWERING_PIPELINE = ( "func.func(tm-tensor-to-loops)", "func.func(refback-munge-memref-copy)", "func.func(convert-linalg-to-loops)", + "func.func(expand-realloc)", "func.func(lower-affine)", "convert-scf-to-cf", "func.func(refback-expand-ops-for-llvm)", @@ -193,6 +194,7 @@ LOWERING_PIPELINE = ( "convert-bufferization-to-memref", "finalize-memref-to-llvm", "func.func(convert-arith-to-llvm)", + "convert-vector-to-llvm", "convert-func-to-llvm", "convert-cf-to-llvm", "convert-complex-to-llvm", From 72349f7522195645d1af7b468bea15b64a37b105 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:23:09 -0500 Subject: [PATCH 20/30] [TorchToLinalg] Adds Quantization Support for ConvTranspose (#3240) I spent a little while debugging numerics issues with some tests similar to the ones in quantized_models.py, only to find that pytorch's quantized conv transpose is catastrophically inaccurate. I'll upstream the issue and only leave the tests here which are of the form quantize -> dequantize -> op. --- lib/Conversion/TorchToLinalg/Linear.cpp | 59 +++++++++++-------- projects/pt1/e2e_testing/xfail_sets.py | 5 ++ .../torch_mlir_e2e_test/test_suite/conv.py | 53 +++++++++++++++++ 3 files changed, 92 insertions(+), 25 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 3f4e6ed66..c49646e2f 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -43,7 +43,8 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, if (!isUnsignedType) return; int64_t minSI = -(1 << (numBits - 1)); - Value minSIValue = rewriter.create(loc, minSI, 32); + Value minSIValue = rewriter.create( + loc, minSI, zp.getType().cast().getWidth()); zp = rewriter.create(loc, zp, minSIValue); minSIValue = rewriter.create(loc, minSI, numBits); arg = torch_to_linalg::createElementwiseLinalgGeneric( @@ -797,6 +798,8 @@ public: auto resultTy = cast(op.getType()); Value inputZp, weightZp; + bool inputUnsigned = false; + bool weightUnsigned = false; if (auto make = op.getInput() .getDefiningOp()) { input = make.getSelf(); @@ -806,6 +809,8 @@ public: inputZp = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(inputZp.getType()), inputZp); + auto torchDtype = cast(make.getType()).getDtype(); + inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); } if (auto make = op.getWeight() @@ -818,6 +823,8 @@ public: weightZp = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(weightZp.getType()), weightZp); + auto torchDtype = cast(make.getType()).getDtype(); + weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); } if (static_cast(inputZp) != static_cast(weightZp)) { @@ -916,15 +923,35 @@ public: SmallVector strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts); + // convert any uint8 quantization to int8 quantization + if (auto integerType = dyn_cast(inputDTy)) { + int64_t width = 
integerType.getWidth(); + signShift(rewriter, loc, input, inputZp, inputUnsigned, width); + } + if (auto integerType = dyn_cast(weightDTy)) { + int64_t width = integerType.getWidth(); + signShift(rewriter, loc, weight, weightZp, weightUnsigned, width); + } // Pad the input tensor according to padding. SmallVector outDims{inBatch, weightBatch}; Value paddedInput; - if (transposed) { - if (!isa(inputDTy) || !isa(weightDTy) || - !isa(resultDTy)) - return rewriter.notifyMatchFailure( - op, "transpose does not support non-fp type yet"); + Value pad = inputZp; + if (!pad) { + if (isa(inputDTy)) + pad = rewriter.create( + op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0)); + if (isa(inputDTy)) + pad = rewriter.create( + op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0)); + } + if (pad.getType() != inputDTy) { + if (isa(inputDTy)) + pad = rewriter.create(op.getLoc(), inputDTy, pad); + if (isa(inputDTy)) + pad = rewriter.create(op.getLoc(), inputDTy, pad); + } + if (transposed) { Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); Value c1 = @@ -994,7 +1021,7 @@ public: // Allocate padded input tensor Value initTensor = - createZeroInitTensor(rewriter, loc, outerSizes, inputDTy); + createInitTensor(rewriter, loc, outerSizes, inputDTy, pad); // Insert input into allocated tensor SmallVector strideIndexValues{c1, c1}; @@ -1017,24 +1044,6 @@ public: strideInts.clear(); strideInts.append(numSpatialDims, 1); } else { - Value pad = inputZp; - if (!pad) { - if (isa(inputDTy)) - pad = rewriter.create( - op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0)); - if (isa(inputDTy)) - pad = rewriter.create( - op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0)); - } - - if (pad.getType() != inputDTy) { - if (isa(inputDTy)) - pad = rewriter.create(op.getLoc(), inputDTy, pad); - - if (isa(inputDTy)) - pad = rewriter.create(op.getLoc(), inputDTy, pad); - } - // Pad input paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor( op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 8ffe8d1c3..fcb7e053a 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -272,6 +272,7 @@ TORCHDYNAMO_XFAIL_SET = { "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", "Conv2dQInt8Module_basic", + "ConvTranspose2DQInt8_basic", # Dynamo not supporting conv_tbc "ConvTbcModule_basic", "FloatImplicitModule_basic", @@ -372,6 +373,7 @@ FX_IMPORTER_XFAIL_SET = { "Conv2dQInt8Module_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "ConvTbcModule_basic", + "ConvTranspose2DQInt8_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", @@ -544,6 +546,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "ContainsIntList_True", "Conv2dQInt8Module_basic", "ConvTbcModule_basic", + "ConvTranspose2DQInt8_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", @@ -2100,6 +2103,7 @@ LTC_XFAIL_SET = { "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", "Conv2dQInt8Module_basic", + "ConvTranspose2DQInt8_basic", } ONNX_XFAIL_SET = { @@ -2254,6 +2258,7 @@ ONNX_XFAIL_SET = { "Conv2dWithPaddingModule_basic", "Conv3dModule_basic", "ConvTbcModule_basic", + "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", "Convolution2DModule_basic", "Convolution2DStridedModule_basic", diff --git 
a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 9600b0900..e99525c32 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1046,3 +1046,56 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils): weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8) bias = torch.rand(3) module.forward(inputVec, weight, bias) + + +N = 10 +Cin = 5 +Cout = 7 +Hin = 10 +Win = 8 +Hker = 3 +Wker = 2 + + +class ConvTranspose2DQInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.int8, True), + ([-1, -1, -1, -1], torch.int8, True), + ([-1], torch.float, True), + ] + ) + def forward(self, input, weight, bias): + qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25) + qinput = torch.dequantize(qinput) + qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50) + qweight = torch.dequantize(qweight) + qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32) + qbias = torch.dequantize(qbias) + qz = torch.ops.aten.convolution( + qinput, + qweight, + bias=qbias, + stride=[2, 1], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1, + ) + return qz + + +@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module()) +def ConvTranspose2DQInt8_basic(module, tu: TestUtils): + module.forward( + tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8), + tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8), + torch.rand(Cout), + ) From 315dc6c3e377b74e8981776237cff2e733667811 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Tue, 30 Apr 2024 13:41:03 -0400 Subject: [PATCH 21/30] [torch] `aten.eye` should use dynamic dims when no static dims are available (#3202) Co-authored-by: Xida Ren --- .../Torch/Transforms/DecomposeComplexOps.cpp | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 677ccc4f2..cc21f2155 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1059,44 +1059,44 @@ public: LogicalResult matchAndRewrite(AtenEyeMOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - int64_t n; - - if (!matchPattern(op.getN(), m_TorchConstantInt(&n))) - return rewriter.notifyMatchFailure(op, - "unimplemented: n must be constant"); - int64_t m; - if (!matchPattern(op.getM(), m_TorchConstantInt(&m))) - return rewriter.notifyMatchFailure(op, - "unimplemented: m must be constant"); - Value none = rewriter.create(loc); - auto outType = dyn_cast(op.getType()); + auto outType = op.getType().dyn_cast(); if (!outType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); if (!outType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } - if (n < 0) { - return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0"); - } - if (m < 0) { - return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0"); - } - + Value none = rewriter.create(loc); auto context = op.getContext(); auto int64Dtype = getDtypeIntValueForType( rewriter, loc, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); auto si64Type = IntegerType::get(context, 64, 
IntegerType::Signed); - auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type); + + int64_t n = kUnknownSize; + int64_t m = kUnknownSize; + // prioritize getting shape from output shape + if (outType.hasSizes() && outType.getSizes().size() == 2) { + n = outType.getSizes().front(); + m = outType.getSizes().back(); + } + // if output shape is not available, try to get shape from input + if (n == kUnknownSize) + matchPattern(op.getN(), m_TorchConstantInt(&n)); + if (m == kUnknownSize) + matchPattern(op.getM(), m_TorchConstantInt(&m)); + + // prepare two unsqueezed ranges that are equal on and only on the diagonal + auto rangeNSize = llvm::SmallVector({n}); + Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type); Value rangeN = rewriter.create( - loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none, + loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/op.getDevice(), /*pin_memory=*/none); - auto arangeType1 = - outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type); + auto rangeMSize = llvm::SmallVector({m}); + Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type); Value rangeM = rewriter.create( - loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none, + loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); Value constMinusOne = rewriter.create( @@ -1109,7 +1109,6 @@ public: } Value unsqzRangeN = *unsqzTensorInfo; - // compare unsqueezed input with boundaries auto eqType = ValueTensorType::get( context, cast(op.getType()).getSizes(), IntegerType::get(context, 1)); From 33eef15e428f848e3848d1038ed71faab893a686 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Tue, 30 Apr 2024 14:36:40 -0400 Subject: [PATCH 22/30] Support onnx.If (#2825) This is probably a decent PR for learning about blocks and regions. If you're here to learn about that, consider also looking at lib/Conversion/TorchToSCF/TorchToSCF.cpp While this doesn't include an e2e test, it is tested downstream in https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/onnx/operators/If/model.py --------- Co-authored-by: Xida Ren --- .../Conversion/TorchOnnxToTorch/Patterns.h | 25 +++++++++ .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 54 +++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 12 ++--- .../test_suite/diagonal.py | 31 +++++++++++ python/torch_mlir/extras/onnx_importer.py | 8 ++- test/Conversion/TorchOnnxToTorch/ops/if.mlir | 20 +++++++ 6 files changed, 141 insertions(+), 9 deletions(-) create mode 100644 test/Conversion/TorchOnnxToTorch/ops/if.mlir diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index d3260500c..3230cc8b4 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -97,6 +97,31 @@ struct OpBinder { return success(); } + ParseResult tensorResultTypes(llvm::SmallVector &typeList) { + for (auto result : op->getResults()) { + auto t = toValidTensorType(result.getType()); + if (!t) + return failure(); + typeList.push_back(t); + } + return success(); + } + + // The importer imports Onnx.GraphProto attributes as regions attached to the + // op. 
+ ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) { + if (idx >= op->getNumRegions()) + return failure(); + + region = &op->getRegion(idx); + + if (region == nullptr) { + return failure(); + } + + return success(); + } + ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, int64_t idx) { if (idx >= op->getNumResults()) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 7a150794c..1f1e2e5d7 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( alignCorners); return success(); }); + patterns.onOp( + "If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value conditionTensor; + if (binder.tensorOperand(conditionTensor)) { + return rewriter.notifyMatchFailure(binder.op, + "condition bind failure"); + } + + auto conditionType = + conditionTensor.getType().cast(); + if (!conditionType || conditionType.getSizes().size() != 1) + return rewriter.notifyMatchFailure( + binder.op, "condition must have one single element per " + "https://onnx.ai/onnx/operators/onnx__If.html"); + auto conditionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + conditionTensor); + auto conditionBool = rewriter.create( + binder.getLoc(), rewriter.getType(), conditionInt); + + llvm::SmallVector resultTypes; + if (binder.tensorResultTypes(resultTypes)) { + return rewriter.notifyMatchFailure(binder.op, + "result type bind failure"); + } + + Region *thenRegion, *elseRegion; + if (binder.getRegionAtIndex(elseRegion, 0) || + binder.getRegionAtIndex(thenRegion, 1)) { + return rewriter.notifyMatchFailure(binder.op, "region bind failure"); + } + + auto primIfOp = rewriter.create( + binder.getLoc(), TypeRange(resultTypes), conditionBool); + + auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) { + rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin()); + }; + inlineIfCase(*thenRegion, primIfOp.getThenRegion()); + inlineIfCase(*elseRegion, primIfOp.getElseRegion()); + + auto replaceTerminator = [&](Region ®ion) { + PatternRewriter::InsertionGuard guard(rewriter); + Operation *terminator = region.front().getTerminator(); + rewriter.setInsertionPoint(terminator); + rewriter.replaceOpWithNewOp( + terminator, terminator->getOperands()); + }; + replaceTerminator(primIfOp.getThenRegion()); + replaceTerminator(primIfOp.getElseRegion()); + + rewriter.replaceOp(binder.op, primIfOp.getResults()); + return success(); + }); patterns.onOp("Less", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index fcb7e053a..25d8fa9be 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2562,16 +2562,12 @@ ONNX_XFAIL_SET = { "_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic", "_SoftmaxModule_basic", + # Failure - onnx_import # Failure - onnx_lowering: onnx.AveragePool "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", - # Failure - onnx_lowering: onnx.If - "DiagonalModule_basic", - "DiagonalModule_nonsquare", - "DiagonalModule_transposed", - "DiagonalModule_with_dims", - "DiagonalModule_with_dims_and_offset", - "DiagonalModule_with_negative_dims", - "DiagonalModule_with_offset", + # these diagonal modules are currently failing 
due to dynamic shape. + # We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead. + # when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here. "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", # Failure - onnx_lowering: onnx.MaxPool diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py index 6371f9a8d..3bd3796da 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py @@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils): # ============================================================================== +class DiagonalWithStaticShapeModule(torch.nn.Module): + """ + Diagonal with static shape. The other diagonal modules are failing in onnx + because DecomoposeAtenEyeMOp requires constants n, m, which are only constant + when the shape is static. + + Please remove this module and associated test once the issue is fixed. + """ + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 9], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.diagonal(a) + + +@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule()) +def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 9)) + + +# ============================================================================== + + class DiagonalTransposedModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index 8d0e4cf5a..e0d3529d9 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -347,8 +347,14 @@ class NodeImporter: continue elif handler is False: # Active error. + # try matching attribute type ID to name for a more descriptive error message + try: + attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type) + except ValueError: + attr_type_name = "UNKNOWN" raise OnnxImportError( - f"ONNX importer does not support generic node attribute type {attr_type}. " + f"ONNX importer does not support generic node attribute type {attr_type_name} " + f"with ID {attr_type}. 
" f"This likely means that this is a special node which requires specific " f"handling in the importer: {onnx_attr}" ) diff --git a/test/Conversion/TorchOnnxToTorch/ops/if.mlir b/test/Conversion/TorchOnnxToTorch/ops/if.mlir new file mode 100644 index 000000000..1d95a3f5f --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/ops/if.mlir @@ -0,0 +1,20 @@ +// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s + +// CHECK-LABEL: func.func @test_ifop_basic +// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>) +// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32> +// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32> +// CHECK-DAG: } else { +// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32> +// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32> +func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + %0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> { + %1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> + torch.operator_terminator %1 : !torch.vtensor<[1],f32> + }, { + %1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> + torch.operator_terminator %1 : !torch.vtensor<[1],f32> + } + return %0 : !torch.vtensor<[1],f32> +} From 0a2d21b108602d2b11c208ca1a713a72f483f6c1 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Tue, 30 Apr 2024 17:48:01 -0400 Subject: [PATCH 23/30] Add `.yamllint` and disable some annoying recurring warnings on every pr (#3224) --- .yamllint.yml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 .yamllint.yml diff --git a/.yamllint.yml b/.yamllint.yml new file mode 100644 index 000000000..ec40711eb --- /dev/null +++ b/.yamllint.yml @@ -0,0 +1,22 @@ +--- + +extends: default + +rules: + # These do not appear to be conventional in GitHub actions. + document-end: + present: false + document-start: + present: false + # GitHub actions use "on" for triggers. + truthy: disable + # We have lots of long strings and command lines. + line-length: disable + comments: + # Formatters may do this (e.g. Prettier does) and it seems like the most + # trivial thing to get a failing check for. + min-spaces-from-content: 1 + # This is not a useful check, especially when disabling entire blocks. + comments-indentation: disable + +ignore: /third_party/* From 8c48135a426b84fa412b031fc92e12826ff60b31 Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Wed, 1 May 2024 12:06:53 +0530 Subject: [PATCH 24/30] [linalg] Fix bug for conversion of complex dtype (#3269) The conversion of complex type wasn't supported or checked; the support and required tests were added. 
Fixes: https://github.com/iree-org/iree/issues/17226#issuecomment-2087779158 --- lib/Conversion/Utils/Utils.cpp | 21 ++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 2 ++ .../test_suite/elementwise.py | 28 +++++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index bae25cc7a..e014fbeaa 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -10,6 +10,7 @@ #include "torch-mlir/Conversion/Utils/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -349,6 +350,26 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return b.create(loc, dtype, scalar); } + if (auto dtypeComplex = dyn_cast(dtype)) { + if (auto scalarComplex = dyn_cast(scalarType)) { + auto dtypeElemType = dtypeComplex.getElementType(); + + // Extract the real and imaginary parts of the scalar. + // Cast them to the target element type, and create a new complex + // value with the target complex type. + Value realVal = b.create(loc, scalar); + Value imgVal = b.create(loc, scalar); + + realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType); + imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType); + + return b.create(loc, dtypeComplex, realVal, imgVal); + } + mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype " + << scalarType << "(scalar type) -> " << dtype + << "(dtype)"; + } + llvm_unreachable("convertScalarToDtype should handle all the types"); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 25d8fa9be..33f1ed702 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -575,6 +575,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "ElementwiseErfIntModule_basic", "ElementwiseLogitModule_basic", "ElementwiseMulTensorComplexModule_basic", + "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", @@ -2314,6 +2315,7 @@ ONNX_XFAIL_SET = { "ElementwiseExpm1Module_basic", "ElementwiseFmodTensor_Int_basic", "ElementwiseMulTensorComplexModule_basic", + "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseOrTensorModule_basic", "ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseQuantizePerTensorModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 8e2875842..a26fd9809 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1839,6 +1839,34 @@ def ElementwiseMulTensorComplexModule_basic(module, tu: TestUtils): # ============================================================================== +# torch.complex32 is not supported by the refbackend. 
+class ElementwiseMulTensorComplexDiffModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.complex64, True), + ([-1], torch.complex128, True), + ] + ) + def forward(self, a, b): + return torch.mul(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseMulTensorComplexDiffModule()) +def ElementwiseMulTensorComplexDiffModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, high=10).type(torch.complex64), + tu.randint(4, high=10).type(torch.complex128), + ) + + +# ============================================================================== + + class ElementwiseMishModule(torch.nn.Module): def __init__(self): super().__init__() From 11cd7cd9e7705fd69f40fabdad2e0e5b5b738914 Mon Sep 17 00:00:00 2001 From: Ze Zhang Date: Thu, 2 May 2024 00:03:41 -0700 Subject: [PATCH 25/30] Folder and Canonicalizer for PrimsConvertElementTypeOp and AtenMaxPool2dWithIndicesOp (#3272) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit While playing with TorchDynamo on ResNet18. I notice following issues: - `prims.convert_element_type` can’t be canonicalized even if the input and the output share the same type - `aten.max_pool2d_with_indices` is always used instead of `aten.max_pool2d`, even if the second returned output (indices) has no user This PR fixes above issues by adding a folder to the PrimsConvertElementTypeOp and a canonicalizer to the AtenMaxPool2dWithIndicesOp Lit test: `cmake --build build --target check-torch-mlir-all` --------- Co-authored-by: Ze Zhang --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 2 + lib/Dialect/Torch/IR/TorchOps.cpp | 39 ++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 5 --- .../build_tools/torch_ods_gen.py | 5 ++- test/Dialect/Torch/canonicalize.mlir | 41 +++++++++++++++++++ 5 files changed, 85 insertions(+), 7 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index cb08ffd53..95d92af99 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6720,6 +6720,7 @@ def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", printDefaultTorchOp(printer, *this, 6, 2); } }]; + let hasCanonicalizer = 1; } def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_indices_backward", [ @@ -15982,6 +15983,7 @@ def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_PrimsVarOp : Torch_Op<"prims.var", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 29911961d..1d0ff41f7 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -4715,6 +4715,45 @@ LogicalResult AtenPermuteOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// PrimsConvertElementTypeOp +//===----------------------------------------------------------------------===// + +OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) { + auto inputType = cast(getA().getType()); + auto outputType = cast(getResult().getType()); + if (inputType != outputType) + return nullptr; + if (!inputType.hasDtype() || !outputType.hasDtype()) + return nullptr; + if (inputType.getDtype() != outputType.getDtype()) + return nullptr; + return getA(); +} 
+ +//===----------------------------------------------------------------------===// +// AtenMaxPool2dWithIndicesOp +//===----------------------------------------------------------------------===// + +void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) { + if (!op.getResult1().use_empty()) { + return rewriter.notifyMatchFailure( + op, "result1 of MaxPool2dWithIndices should be unused"); + } + + Value result = rewriter.create( + op->getLoc(), op.getResult0().getType(), op.getSelf(), + op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(), + op.getCeilMode()); + + op.getResult0().replaceAllUsesWith(result); + rewriter.eraseOp(op); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenLinalgCrossOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 33f1ed702..d8529cb38 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1924,11 +1924,6 @@ MAKE_FX_TOSA_PASS_SET = ( # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", "MatmulStaticBroadcast_basic", - # failed to legalize operation 'torch.aten.max_pool2d_with_indices - "MaxPool2dEmptyStrideStaticModule_basic", - "MaxPool2dStaticCeilModeTrueModule_basic", - "MaxPool2dStaticModule_basic", - "ResNet18StaticModule_basic", # Unimplemented operator 'aten._index_put_impl_.hacked_twin' "IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DIntNonAccumulateModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index e0329c8df..d4d547456 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -594,7 +594,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit( - "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" + "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)", + has_canonicalizer=True, ) emit( "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" @@ -1104,7 +1105,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): # `prims::` namespace. # ========================================================================== - emit("prims::convert_element_type : (Tensor, int) -> (Tensor)") + emit("prims::convert_element_type : (Tensor, int) -> (Tensor)", has_folder=True) emit("prims::var : (Tensor, int[]?, float, int?) 
-> (Tensor)") emit("prims::sqrt : (Tensor) -> (Tensor)") emit("prims::collapse : (Tensor, int, int) -> (Tensor)") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index a317e4011..e7605f661 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2974,3 +2974,44 @@ func.func @aten_log$fold_splat_f32() -> !torch.vtensor<[4], f32> { %result = torch.aten.log %cst : !torch.vtensor<[4], f32> -> !torch.vtensor<[4], f32> return %result : !torch.vtensor<[4], f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold( +// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> { +// CHECK: return %[[ARG]] : !torch.vtensor<[64],f32> +func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],f32> { + %int6 = torch.constant.int 6 + %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + return %0 : !torch.vtensor<[64],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.prims.convert_element_type$no_fold( +// CHECK: %[[ARG:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> { +// CHECK: %[[RET:.*]] = torch.prims.convert_element_type %[[ARG]], %{{.*}} : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32> +// CHECK: return %[[RET]] : !torch.vtensor<[64],si32> +func.func @torch.prims.convert_element_type$no_fold(%arg0: !torch.vtensor<[64],f32>) -> !torch.vtensor<[64],si32> { + %int6 = torch.constant.int 6 + %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],si32> + return %0 : !torch.vtensor<[64],si32> +} + +// ----- + +// CHECK-LABEL: @torch.aten.max_pool2d_with_indices$canonicalize( +// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> { +// CHECK: %[[RET:.*]] = torch.aten.max_pool2d %[[ARG]] +// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56],f32> +func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112],f32>) -> !torch.vtensor<[10,64,56,56],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64> + return %result0 : !torch.vtensor<[10,64,56,56],f32> +} From 0bb62e4347d239018797b0829b44cdbffa78a3a2 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 2 May 2024 21:30:24 +0530 Subject: [PATCH 26/30] Revert Onnx.Selu lowering to corresponding Aten op (#3275) --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 28 ++----------------- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 16 +++-------- 2 files changed, 7 insertions(+), 37 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 5f9da3faa..3553d22c7 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ 
b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -851,9 +851,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.tensorResultType(resultType)) return failure(); - Torch::ValueTensorType inputType = - operand.getType().cast(); - Value vAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), alpha)); @@ -862,31 +859,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), gamma)); - Value cstOne = rewriter.create( + Value vInputScale = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), 1.0)); - Value cstNone = rewriter.create(binder.getLoc()); - Value zeroTensor = rewriter.create( - binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone, - cstNone, cstNone); - Value exp = rewriter.create(binder.getLoc(), - resultType, operand); - Value expMulAlpha = rewriter.create( - binder.getLoc(), resultType, exp, vAlpha); - Value expMulAlphaSubAlpha = rewriter.create( - binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne); - Value neg = rewriter.create( - binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale); - Value pos = rewriter.create( - binder.getLoc(), resultType, operand, vScale); - Type compareType = inputType.getWithSizesAndDtype( - inputType.getOptionalSizes(), rewriter.getI1Type()); - Value xLessThanZero = rewriter.create( - binder.getLoc(), compareType, operand, zeroTensor); - - rewriter.replaceOpWithNewOp( - binder.op, resultType, xLessThanZero, neg, pos); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, vAlpha, vScale, vInputScale); return success(); }); patterns.onOp("ReduceL1", 1, diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 2748a640a..0c2b9180c 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -582,18 +582,10 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to // CHECK-LABEL: func.func @test_selu func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} { - // CHECK: %[[F2:.+]] = torch.constant.float 2.000000e+00 - // CHECK: %[[F3:.+]] = torch.constant.float 3.000000e+00 - // CHECK: %[[F1:.+]] = torch.constant.float 1.000000e+00 - // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[ZEROS:.+]] = torch.aten.zeros_like %arg0, %none, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[EXP:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[EXP]], %[[F2]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[MUL]], %[[F2]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[MUL_1:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[MUL_2:.+]] = torch.aten.mul.Scalar %arg0, %[[F3]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[LT:.+]] = torch.aten.lt.Tensor %arg0, %[[ZEROS]] : !torch.vtensor<[3,4,5],f32>, 
!torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> - // CHECK: torch.aten.where.self %[[LT]], %[[MUL_1]], %[[MUL_2]] : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1 + // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2 + // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3 + // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]] %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } From a46fe2c9db1a843ec96a473c713887dea7b6a206 Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Thu, 2 May 2024 09:47:45 -0700 Subject: [PATCH 27/30] [MLIR][ONNX] Add OnnxToTorch support for ReduceSumSquare Op (#3188) This commit adds the OnnxToTorch support for ReduceSumSquare ops. --------- Co-authored-by: Ubuntu --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 52 +++++-- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 144 ++++++++++++++---- 2 files changed, 150 insertions(+), 46 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 3553d22c7..b2cbfa3e8 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -940,22 +940,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*memory_format=*/noneVal); return success(); }); - patterns.onOp("ReduceSum", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); - - return reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, false); - }); patterns.onOp("ReduceLogSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -982,6 +966,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, data); return success(); }); + patterns.onOp("ReduceSum", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, + "noop_with_empty_axes", 0)) + return failure(); + + return reducedSumImpl(binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, false); + }); + patterns.onOp("ReduceSumSquare", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, + "noop_with_empty_axes", 0)) + return failure(); + + Value dataSquare = rewriter.create( + binder.getLoc(), data.getType(), data, data); + + return 
reducedSumImpl(binder, rewriter, dataSquare, + resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, false); + }); patterns.onOp( "ReduceMean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 0c2b9180c..56df1b2e0 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -860,6 +860,57 @@ func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2 // ----- +// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example +func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example +func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example +func.func 
@test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 @@ -942,41 +993,24 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< // ----- -// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example -func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-LABEL: func.func @test_reduce_sum_square_default_axes_keepdims_example +func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32> - // 
CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32> - %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: return %[[SUM]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> return %0 : !torch.vtensor<[1,1,1],f32> } // ----- -// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example -func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list - // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> - // CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32> - %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> - return %0 : !torch.vtensor<[3,2,1],f32> -} - -// ----- - -// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example -func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example +func.func @test_reduce_sum_square_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> @@ -984,15 +1018,65 @@ func.func 
@test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2 // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32> - // CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32> - %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: return %[[SUM]] : !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> return %0 : !torch.vtensor<[3,2],f32> } // ----- +// CHECK-LABEL: func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero +func.func @test_reduce_sum_square_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 8: si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[2,0,4],f32>, !torch.vtensor<[2,0,4],f32> -> !torch.vtensor<[2,0,4],f32> + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT2]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[2,0,4],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32> + // CHECK: return %[[SUM]] : !torch.vtensor<[2,0,1],f32> + %0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> + return %0 : !torch.vtensor<[2,0,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_example +func.func @test_reduce_sum_square_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] 
: (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + // CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_sum_square_keepdims_int_example +func.func @test_reduce_sum_square_keepdims_int_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],si64>, !torch.vtensor<[3,2,2],si64> -> !torch.vtensor<[3,2,2],si64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],si64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + // CHECK: return %[[SUM]] : !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceSumSquare"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// ----- + // CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> From 67d6a665a450eb0f69fbade238717f6dd3f56654 Mon Sep 17 00:00:00 2001 From: Vinayak Dev <104419489+vinayakdsci@users.noreply.github.com> Date: Fri, 3 May 2024 21:34:57 +0530 Subject: [PATCH 28/30] [torch] Add OnnxToTorch lowering for `onnx.HannWindow` (#3276) Adds OnnxToTorch lowering for the `onnx.HannWindow` op. Also factors out common implementation between the window functions. 
---
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    | 227 +++++++++++-------
 .../TorchOnnxToTorch/simple_ops_a_to_f.mlir   |  79 ++++++
 2 files changed, 213 insertions(+), 93 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index bd5c57fac..674af7399 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -35,6 +35,108 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
   return success();
 }
 
+namespace {
+LogicalResult windowFunctionImpl(OpBinder binder,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value size, Value a0, Value a1, Value a2,
+                                 Torch::ValueTensorType resultType,
+                                 int64_t output_datatype, int64_t periodic) {
+
+  Location loc = binder.getLoc();
+  ImplicitLocOpBuilder b(loc, rewriter);
+
+  double isPeriodicFp = static_cast<double>(periodic);
+
+  Value zero = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(0.0));
+  Value one = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(1.0));
+  Value two = b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(2.0));
+
+  constexpr double pi = llvm::numbers::pi;
+  Value tau = b.create<Torch::ConstantFloatOp>(
+      rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
+
+  Value noneVal = b.create<Torch::ConstantNoneOp>();
+  Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
+  Value float32Type = b.create<Torch::ConstantIntOp>(
+      rewriter.getI64IntegerAttr(/*float32Type*/ 6));
+
+  // Create an f32 ValueTensorType with the same size as size, the
+  // operand
+  auto shapeOfOperand =
+      size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
+  auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
+      shapeOfOperand, rewriter.getF32Type());
+  Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
+      f32ResultType, size, float32Type, cstFalse, cstFalse, noneVal);
+  Value symmetricSizeFloat = b.create<Torch::AtenSubScalarOp>(
+      periodicSizeFloat.getType(), periodicSizeFloat, one, one);
+
+  Value isPeriodic =
+      b.create<Torch::ConstantFloatOp>(rewriter.getF64FloatAttr(isPeriodicFp));
+  Value isSymmetricFloat = b.create<Torch::ConstantFloatOp>(
+      rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
+
+  Value periodicComponent = b.create<Torch::AtenMulScalarOp>(
+      periodicSizeFloat.getType(), periodicSizeFloat, isPeriodic);
+  Value symmetricComponent = b.create<Torch::AtenMulScalarOp>(
+      symmetricSizeFloat.getType(), symmetricSizeFloat, isSymmetricFloat);
+  Value sizeFloat = b.create<Torch::AtenAddTensorOp>(
+      symmetricComponent.getType(), symmetricComponent, periodicComponent, one);
+
+  // Here, size can be used in the place of periodicSizeFloat, as the
+  // latter is just a float representation of the former.
+  Value scalarLimit = getItemOp<Torch::AtenItemOp>(binder, rewriter, size);
+
+  Value rangeArr = b.create<Torch::AtenArangeStartStepOp>(
+      resultType, zero, scalarLimit, one, noneVal, noneVal, noneVal, noneVal);
+
+  Value rangeTimesTau =
+      b.create<Torch::AtenMulScalarOp>(resultType, rangeArr, tau);
+  Value rangeAngular =
+      b.create<Torch::AtenDivTensorOp>(resultType, rangeTimesTau, sizeFloat);
+  Value twoRangeAngular =
+      b.create<Torch::AtenMulScalarOp>(resultType, rangeAngular, two);
+
+  Value cosRangeAngular = b.create<Torch::AtenCosOp>(resultType, rangeAngular);
+  Value cosTwoRangeAngular =
+      b.create<Torch::AtenCosOp>(resultType, twoRangeAngular);
+
+  Value a1Component =
+      b.create<Torch::AtenMulScalarOp>(resultType, cosRangeAngular, a1);
+  Value a2Component =
+      b.create<Torch::AtenMulScalarOp>(resultType, cosTwoRangeAngular, a2);
+
+  // AtenSubScalarOp actually requires a tensor operand as the LHS, that
+  // is, operand #1. Therefore, to avoid errors, the onnx implementation
+  // has been modified. a1 has been changed to negative half, and the
+  // AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add
+  // operation is commutative.
+  Value subA1Component =
+      b.create<Torch::AtenAddScalarOp>(resultType, a1Component, a0, one);
+  Value result = b.create<Torch::AtenAddTensorOp>(resultType, subA1Component,
+                                                  a2Component, one);
+
+  std::optional<int64_t> dtypeIntTorch =
+      onnxDtypeIntToTorchDtypeInt(output_datatype);
+  if (!dtypeIntTorch.has_value()) {
+    return rewriter.notifyMatchFailure(
+        binder.op, "unimplemented support for the given dtype conversion");
+  }
+  Value outputDtype = b.create<Torch::ConstantIntOp>(
+      rewriter.getType<Torch::IntType>(),
+      rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                              dtypeIntTorch.value()));
+
+  rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+      binder.op, resultType, result, outputDtype,
+      /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+      /*memory_format=*/noneVal);
+
+  return success();
+}
+
+} // namespace
+
 // Simple rewrites for the default domain.
 // See: https://onnx.ai/onnx/operators/
 // For operators that are effectively version invariant, we register with
@@ -2252,7 +2354,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           binder.tensorResultType(resultType)) {
         return failure();
       }
-      double isPeriodicFp = static_cast<double>(periodic);
 
       Value a0 = rewriter.create<Torch::ConstantFloatOp>(
           binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.42));
@@ -2262,104 +2363,44 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       Value a2 = rewriter.create<Torch::ConstantFloatOp>(
          binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.08));
 
-      Value zero = rewriter.create<Torch::ConstantFloatOp>(
-          binder.getLoc(), rewriter.getF64FloatAttr(0.0));
-      Value one = rewriter.create<Torch::ConstantFloatOp>(
-          binder.getLoc(), rewriter.getF64FloatAttr(1.0));
-      Value two = rewriter.create<Torch::ConstantFloatOp>(
-          binder.getLoc(), rewriter.getF64FloatAttr(2.0));
-      constexpr double pi = llvm::numbers::pi;
-      Value tau = rewriter.create<Torch::ConstantFloatOp>(
-          binder.getLoc(),
-          rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi));
+      auto windowFunctionResult =
+          windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
+                             output_datatype, periodic);
 
-      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
-      Value cstFalse =
-          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
-      Value float32Type = rewriter.create<Torch::ConstantIntOp>(
-          binder.getLoc(), rewriter.getI64IntegerAttr(/*float32Type*/ 6));
+      if (failed(windowFunctionResult))
+        return failure();
 
-      // Create an f32 ValueTensorType with thse same size as size, the
-      // operand
-      auto shapeOfOperand = size.getType()
-                                .dyn_cast<Torch::ValueTensorType>()
-                                .getOptionalSizes();
-      auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
-          shapeOfOperand, rewriter.getF32Type());
-      Value periodicSizeFloat = rewriter.create<Torch::AtenToDtypeOp>(
-          binder.getLoc(), f32ResultType, size, float32Type, cstFalse,
-          cstFalse, noneVal);
-      Value symmetricSizeFloat = rewriter.create<Torch::AtenSubScalarOp>(
-          binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
-          one, one);
+      return success();
+    });
 
-      Value isPeriodic = rewriter.create<Torch::ConstantFloatOp>(
-          binder.getLoc(), rewriter.getF64FloatAttr(isPeriodicFp));
-      Value isSymmetricFloat = rewriter.create<Torch::ConstantFloatOp>(
-          binder.getLoc(), rewriter.getF64FloatAttr(1.0 - isPeriodicFp));
-
-      Value periodicComponent = rewriter.create<Torch::AtenMulScalarOp>(
-          binder.getLoc(), periodicSizeFloat.getType(), periodicSizeFloat,
-          isPeriodic);
-      Value symmetricComponent = rewriter.create<Torch::AtenMulScalarOp>(
-          binder.getLoc(), symmetricSizeFloat.getType(), symmetricSizeFloat,
-          isSymmetricFloat);
-      Value sizeFloat = rewriter.create<Torch::AtenAddTensorOp>(
-          binder.getLoc(), symmetricComponent.getType(), symmetricComponent,
-          periodicComponent, one);
-
-      // Here, size can be used in the place of periodicSizeFloat, as the
-      // latter is just a float representation of the former.
-      Value scalarLimit = getItemOp<Torch::AtenItemOp>(binder, rewriter, size);
-
-      Value rangeArr = rewriter.create<Torch::AtenArangeStartStepOp>(
-          binder.getLoc(), resultType, zero, scalarLimit, one, noneVal,
-          noneVal, noneVal, noneVal);
-
-      Value rangeTimesTau = rewriter.create<Torch::AtenMulScalarOp>(
-          binder.getLoc(), resultType, rangeArr, tau);
-      Value rangeAngular = rewriter.create<Torch::AtenDivTensorOp>(
-          binder.getLoc(), resultType, rangeTimesTau, sizeFloat);
-      Value twoRangeAngular = rewriter.create<Torch::AtenMulScalarOp>(
-          binder.getLoc(), resultType, rangeAngular, two);
-
-      Value cosRangeAngular = rewriter.create<Torch::AtenCosOp>(
-          binder.getLoc(), resultType, rangeAngular);
-      Value cosTwoRangeAngular = rewriter.create<Torch::AtenCosOp>(
-          binder.getLoc(), resultType, twoRangeAngular);
-
-      Value a1Component = rewriter.create<Torch::AtenMulScalarOp>(
-          binder.getLoc(), resultType, cosRangeAngular, a1);
-      Value a2Component = rewriter.create<Torch::AtenMulScalarOp>(
-          binder.getLoc(), resultType, cosTwoRangeAngular, a2);
-
-      // AtenSubScalarOp actually requires a tensor operand as the LHS, that
-      // is, operand #1. Therefore, to avoid errors, the onnx implementation
-      // has been modified. a1 has been changed to negative half, and the
-      // AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add
-      // operation is commutative.
-      Value subA1Component = rewriter.create<Torch::AtenAddScalarOp>(
-          binder.getLoc(), resultType, a1Component, a0, one);
-      Value result = rewriter.create<Torch::AtenAddTensorOp>(
-          binder.getLoc(), resultType, subA1Component, a2Component, one);
-
-      std::optional<int64_t> dtypeIntTorch =
-          onnxDtypeIntToTorchDtypeInt(output_datatype);
-      if (!dtypeIntTorch.has_value()) {
-        return rewriter.notifyMatchFailure(
-            binder.op,
-            "unimplemented support for the given dtype conversion");
+  patterns.onOp(
+      "HannWindow", 17,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Value size;
+        Torch::ValueTensorType resultType;
+        int64_t periodic, output_datatype;
+        if (binder.tensorOperand(size) ||
+            binder.s64IntegerAttr(output_datatype, "output_datatype", 1) ||
+            binder.s64IntegerAttr(periodic, "periodic", 1) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
        }
-      Value outputDtype = rewriter.create<Torch::ConstantIntOp>(
-          binder.getLoc(), rewriter.getType<Torch::IntType>(),
-          rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                  dtypeIntTorch.value()));
+        Value a0 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.5));
+        Value a1 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(),
+            rewriter.getFloatAttr(rewriter.getF64Type(), -0.5));
+        Value a2 = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
+
+        auto windowFunctionResult =
+            windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType,
+                               output_datatype, periodic);
+
+        if (failed(windowFunctionResult))
+          return failure();
 
-      rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
-          binder.op, resultType, result, outputDtype,
-          /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
-          /*memory_format=*/noneVal);
        return success();
      });
 }
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index a068acbf2..179a75d49 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -2113,3 +2113,82 @@ func.func @test_blackmanwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor
   %0 = torch.operator "onnx.BlackmanWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32>
   return %0 : !torch.vtensor<[10],f32>
 }
+// -----
+
+// CHECK-LABEL: func.func @test_hannwindow
+func.func @test_hannwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes
{torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01 + // CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00 + // CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6 + // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int + // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos %[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> 
!torch.vtensor<[10],f32> + // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32> + // CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32> + + %0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> + return %0 : !torch.vtensor<[10],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_hannwindow_symmetric +func.func @test_hannwindow_symmetric(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[A0:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[A1:.+]] = torch.constant.float -5.000000e-01 + // CHECK-DAG: %[[A2:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[ZERO:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[ONE:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[TWO:.+]] = torch.constant.float 2.000000e+00 + // CHECK-DAG: %[[TAU:.+]] = torch.constant.float 6.2831853071795862 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[INT6_0:.+]] = torch.constant.int 6 + // CHECK-DAG: %[[CAST_0:.+]] = torch.aten.to.dtype %arg0, %[[INT6_0]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMSIZE:.+]] = torch.aten.sub.Scalar %[[CAST_0]], %[[ONE]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.float, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[PERIODIC:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[SYMMETRIC:.+]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[PERIODICCOMP:.+]] = torch.aten.mul.Scalar %[[CAST_0]], %[[PERIODIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SYMMETRICCOMP:.+]] = torch.aten.mul.Scalar %[[SYMMSIZE]], %[[SYMMETRIC]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[SIZEFP:.+]] = torch.aten.add.Tensor %[[SYMMETRICCOMP]], %[[PERIODICCOMP]], %[[ONE]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RANGELIM:.+]] = torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int + // CHECK-DAG: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[ZERO]], %[[RANGELIM]], %[[ONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.float, !torch.int, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RANGETIMESTAU:.+]] = torch.aten.mul.Scalar %[[ARANGE]], %[[TAU]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RANGEANGULAR:.+]] = torch.aten.div.Tensor %[[RANGETIMESTAU]], %[[SIZEFP]] : !torch.vtensor<[10],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[TWORANGEANGULAR:.+]] = torch.aten.mul.Scalar %[[RANGEANGULAR]], %[[TWO]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[COSRANGEANGULAR:.+]] = torch.aten.cos %[[RANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[TWOCOSRANGEANGULAR:.+]] = torch.aten.cos 
%[[TWORANGEANGULAR]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[A1COMP:.+]] = torch.aten.mul.Scalar %[[COSRANGEANGULAR]], %[[A1]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[A2COMP:.+]] = torch.aten.mul.Scalar %[[TWOCOSRANGEANGULAR]], %[[A2]] : !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RES:.+]] = torch.aten.add.Scalar %[[A1COMP]], %[[A0]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.float, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[RESULT:.+]] = torch.aten.add.Tensor %[[RES]], %[[A2COMP]], %[[ONE]] : !torch.vtensor<[10],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[10],f32> + // CHECK-DAG: %[[INT6_1:.+]] = torch.constant.int 6 + // CHECK: %[[CAST_1:.+]] = torch.aten.to.dtype %[[RESULT]], %[[INT6_1]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f32> + // CHECK: return %[[CAST_1]] : !torch.vtensor<[10],f32> + + %0 = torch.operator "onnx.HannWindow"(%arg0) {torch.onnx.periodic = 0 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> + return %0 : !torch.vtensor<[10],f32> +} From 321b844df743dad3ab816583345c4228b4f66d3a Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 3 May 2024 09:06:44 -0700 Subject: [PATCH 29/30] Revert hyperbolic trigonometric decompositions (#3271) We should be using the `torch` path and handling decomposition in the `math` dialect. --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 142 ++++++------------ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 35 ++--- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 53 +------ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 10 +- 4 files changed, 64 insertions(+), 176 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 674af7399..873e02ed9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -300,29 +300,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - patterns.onOp( - "Asinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); - - // log(x + sqrt(x**2 + 1)) - Value square = rewriter.create( - binder.getLoc(), resultType, operand); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value add0 = rewriter.create( - binder.getLoc(), resultType, square, cstOne, cstOne); - Value sqrt = rewriter.create(binder.getLoc(), - resultType, add0); - Value add1 = rewriter.create( - binder.getLoc(), resultType, operand, sqrt, cstOne); - rewriter.replaceOpWithNewOp(binder.op, resultType, - add1); - return success(); - }); + patterns.onOp("Asinh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp("Atan", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -334,33 +322,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); 
return success(); }); - patterns.onOp( - "Atanh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); - - // 1/2 * log((1 + x) / (1 - x)) - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value add = rewriter.create( - binder.getLoc(), resultType, operand, cstOne, cstOne); - Value neg = rewriter.create(binder.getLoc(), - resultType, operand); - Value sub = rewriter.create( - binder.getLoc(), resultType, neg, cstOne, cstOne); - Value div = rewriter.create( - binder.getLoc(), resultType, add, sub); - Value log = - rewriter.create(binder.getLoc(), resultType, div); - Value cstTwo = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2)); - rewriter.replaceOpWithNewOp( - binder.op, resultType, log, cstTwo); - return success(); - }); + patterns.onOp("Atanh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp("Acos", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -372,29 +344,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - patterns.onOp( - "Acosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); - - // log(x + sqrt(x**2 - 1)) - Value square = rewriter.create( - binder.getLoc(), resultType, operand); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value sub = rewriter.create( - binder.getLoc(), resultType, square, cstOne, cstOne); - Value sqrt = rewriter.create(binder.getLoc(), - resultType, sub); - Value add = rewriter.create( - binder.getLoc(), resultType, operand, sqrt, cstOne); - rewriter.replaceOpWithNewOp(binder.op, resultType, - add); - return success(); - }); + patterns.onOp("Acosh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp("BatchNormalization", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -1490,31 +1450,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - patterns.onOp( - "Cosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); - - // 1/2 * (exp(x) + exp(-x)) - Value x = rewriter.create(binder.getLoc(), resultType, - operand); - Value neg = rewriter.create(binder.getLoc(), - resultType, operand); - Value y = - rewriter.create(binder.getLoc(), resultType, neg); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value z = rewriter.create( - binder.getLoc(), resultType, x, y, cstOne); - Value cstTwo = 
rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2)); - rewriter.replaceOpWithNewOp( - binder.op, resultType, z, cstTwo); - return success(); - }); + patterns.onOp("Cosh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp( "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Location loc = binder.getLoc(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index b2cbfa3e8..bea89ffed 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1439,31 +1439,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); - patterns.onOp( - "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); + patterns.onOp("Sinh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); - // 1/2 * (exp(x) – exp(-x)) - Value x = rewriter.create(binder.getLoc(), resultType, - operand); - Value neg = rewriter.create(binder.getLoc(), - resultType, operand); - Value y = - rewriter.create(binder.getLoc(), resultType, neg); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value z = rewriter.create( - binder.getLoc(), resultType, x, y, cstOne); - Value cstTwo = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2)); - rewriter.replaceOpWithNewOp( - binder.op, resultType, z, cstTwo); - return success(); - }); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); // split with fixed-size parts // Arguments: diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 179a75d49..47ed8948c 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -201,14 +201,7 @@ func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // CHECK-LABEL: @test_atanh func.func @test_atanh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[C1:.*]] = torch.constant.int 1 - // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg0, %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[NEG:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[SUB:.*]] = torch.aten.add.Scalar %[[NEG]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[DIV:.*]] = torch.aten.div.Tensor %[[ADD]], %[[SUB]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[LOG:.*]] = torch.aten.log %[[DIV]] : 
!torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[C2:.*]] = torch.constant.int 2 - // CHECK: torch.aten.div.Scalar %[[LOG]], %[[C2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.atanh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Atanh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } @@ -672,13 +665,7 @@ func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5 // CHECK-LABEL: @test_cosh_example func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[C1:.+]] = torch.constant.int 1 - // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> - // CHECK: %[[C2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.div.Scalar %[[ADD]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } @@ -687,13 +674,7 @@ func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[ // CHECK-LABEL: @test_cosh func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[C1:.+]] = torch.constant.int 1 - // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[C2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.div.Scalar %[[ADD]], %[[C2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } @@ -702,12 +683,7 @@ func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // CHECK-LABEL: @test_acosh_example func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[SQUARE:.+]] 
= torch.aten.square %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[C1:.+]] = torch.constant.int 1 - // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3],f32> - // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[SUB]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> - // CHECK: torch.aten.log %[[ADD]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } @@ -716,12 +692,7 @@ func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor< // CHECK-LABEL: @test_acosh func.func @test_acosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[C1:.+]] = torch.constant.int 1 - // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[SUB]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> - // CHECK: torch.aten.log %[[ADD]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } @@ -748,12 +719,7 @@ func.func @test_asin(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // CHECK-LABEL: @test_asinh_example func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[C1:.+]] = torch.constant.int 1 - // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3],f32> - // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> - // CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } @@ -762,12 +728,7 @@ func.func 
@test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor< // CHECK-LABEL: @test_asinh func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[SQUARE:.+]] = torch.aten.square %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[C1:.+]] = torch.constant.int 1 - // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SQUARE]], %[[C1]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[SQRT:.+]] = torch.aten.sqrt %[[ADD]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[ADD_0:.+]] = torch.aten.add.Tensor %arg0, %[[SQRT]], %[[C1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> - // CHECK: torch.aten.log %[[ADD_0]] : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 56df1b2e0..53e352b23 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1341,15 +1341,9 @@ func.func @test_reduce_prod_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>, // ----- -// CHECK-LABEL: func.func @test_sinh_example +// CHECK-LABEL: func.func @test_sinh func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} { - // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[C1:.+]] = torch.constant.int 1 - // CHECK: %[[SUB:.+]] = torch.aten.sub.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> - // CHECK: %[[C2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.div.Scalar %[[SUB]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } From 53299eb224bb1b64a84b72f6d21e545e152b9e8f Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Sun, 5 May 2024 19:56:12 +0800 Subject: [PATCH 30/30] [Stablehlo] Bump stablehlo to ab92adeda9119a6c3914cd42367b0a2b70765e91 (#3285) --- externals/stablehlo | 2 +- test/Conversion/TorchToStablehlo/basic.mlir | 10 +++---- .../TorchToStablehlo/elementwise.mlir | 8 ++--- test/Conversion/TorchToStablehlo/gather.mlir | 6 ++-- test/Conversion/TorchToStablehlo/pooling.mlir | 29 ++++++++++--------- test/Conversion/TorchToStablehlo/scatter.mlir | 4 +-- 6 files changed, 31 insertions(+), 28 deletions(-) diff --git a/externals/stablehlo b/externals/stablehlo index 271e8634d..ab92adeda 
160000
--- a/externals/stablehlo
+++ b/externals/stablehlo
@@ -1 +1 @@
-Subproject commit 271e8634de184fbfafd677d3876170feb6d08c97
+Subproject commit ab92adeda9119a6c3914cd42367b0a2b70765e91
diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir
index 92888616a..d8ec0fa64 100644
--- a/test/Conversion/TorchToStablehlo/basic.mlir
+++ b/test/Conversion/TorchToStablehlo/basic.mlir
@@ -55,7 +55,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt
 // CHECK-LABEL: func.func @torch.aten.reciprocal(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor
-// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor) -> tensor
+// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) <{value = 1.000000e+00 : f32}> : (tensor) -> tensor
 // CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor
 // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32>
@@ -124,7 +124,7 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],
 // CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
 // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor
 // CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
-// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>)
+// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>)
 // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,3,?,?],f32>
 // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
 func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
@@ -152,7 +152,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
 // CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor
 // CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
 // CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor to tensor
-// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor
+// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor
 // CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor to tensor
 // CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor -> !torch.vtensor<[?,3,?,?],f32>
 // CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
@@ -185,7 +185,7 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>)
 // CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor, tensor<1xindex>) -> tensor<3xf32>
 // CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor
 // CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor, tensor<1xindex>) -> tensor<3xf32>
-// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>)
+// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>)
 // CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor -> !torch.vtensor<[?,3,?,?],f32>
 // CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32>
 func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
@@ -214,7 +214,7 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],
 // CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32>
 // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32>
 // CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32>
-// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
+// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>)
 // CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64>
 // CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32>
 // CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64>
diff --git a/test/Conversion/TorchToStablehlo/elementwise.mlir b/test/Conversion/TorchToStablehlo/elementwise.mlir
index 814770bd6..08f1c8a18 100644
--- a/test/Conversion/TorchToStablehlo/elementwise.mlir
+++ b/test/Conversion/TorchToStablehlo/elementwise.mlir
@@ -4,9 +4,9 @@
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor
 // CHECK: %[[STR:.*]] = torch.constant.str "none"
-// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor) -> tensor
-// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor) -> tensor
-// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor) -> tensor
+// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 1.000000e+00 : f32}> : (tensor) -> tensor
+// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 2.000000e+00 : f32}> : (tensor) -> tensor
+// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 5.000000e-01 : f32}> : (tensor) -> tensor
 // CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor
 // CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor
 // CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor -> tensor
@@ -487,7 +487,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch
 // CHECK-LABEL: func.func @torch.aten.relu(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor
-// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor) -> tensor
+// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) <{value = 0.000000e+00 : f32}> : (tensor) -> tensor
 // CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor
 // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32>
diff --git a/test/Conversion/TorchToStablehlo/gather.mlir b/test/Conversion/TorchToStablehlo/gather.mlir
index a88b6e375..df29bf1d4 100644
--- a/test/Conversion/TorchToStablehlo/gather.mlir
+++ b/test/Conversion/TorchToStablehlo/gather.mlir
@@ -10,7 +10,7 @@
 // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor
 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
-// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false} : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
+// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false}> : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32>
 // CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32>
 // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
 // CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32>
@@ -31,7 +31,7 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1
 // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor
 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
-// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor
+// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false}> : (tensor, tensor, tensor<2xi64>) -> tensor
 // CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor
 // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32>
@@ -53,7 +53,7 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic
 // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor
 // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64>
-// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor
+// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false}> : (tensor, tensor, tensor<2xi64>) -> tensor
 // CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor
 // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,1,?],f32>
 // CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32>
diff --git a/test/Conversion/TorchToStablehlo/pooling.mlir b/test/Conversion/TorchToStablehlo/pooling.mlir
index b8fc6cbd8..156c3ff51 100644
--- a/test/Conversion/TorchToStablehlo/pooling.mlir
+++ b/test/Conversion/TorchToStablehlo/pooling.mlir
@@ -14,11 +14,11 @@
 // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list
 // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list
 // CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<0xFF800000> : tensor
-// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({
+// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({
 // CHECK: ^bb0(%[[VAL_8:.*]]: tensor, %[[VAL_9:.*]]: tensor):
 // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor
 // CHECK: stablehlo.return %[[VAL_10]] : tensor
-// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor
+// CHECK: }) : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?,?,?],f32>
 // CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32>
 func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
@@ -46,12 +46,12 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list
 // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list
 // CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0xFF800000> : tensor
-// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
+// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
+// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({
 // CHECK: ^bb0(%[[VAL_8:.*]]: tensor, %[[VAL_9:.*]]: tensor):
 // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor
 // CHECK: stablehlo.return %[[VAL_10]] : tensor
-// CHECK: })
-// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor
+// CHECK: }) : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?,?,?],f32>
 // CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32>
 func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
@@ -96,7 +96,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
 // CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor
 // CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor, tensor<3xi64>) -> tensor
 // CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor
-// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({
+// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) <{padding = dense<0> : tensor<3x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({
 // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor):
 // CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor, tensor) -> tensor
 // CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor, tensor
@@ -105,7 +105,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
 // CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor, tensor
 // CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor, tensor
 // CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor, tensor
-// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor, tensor, tensor) -> (tensor, tensor)
+// CHECK: }) : (tensor, tensor, tensor, tensor) -> (tensor, tensor)
 // CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor -> !torch.vtensor<[?,?,?],f32>
 // CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor -> !torch.vtensor<[?,?,?],si64>
 // CHECK: return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>
@@ -137,11 +137,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list
 // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list
 // CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor
-// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({
+// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]])
+// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({
 // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor, %[[IVAL_1:.*]]: tensor):
 // CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor
 // CHECK: stablehlo.return %[[IVAL_2]] : tensor
-// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor
+// CHECK: }) : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor
 // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor
@@ -158,11 +159,12 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>
 // CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64>
 // CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor, tensor<4xi64>) -> tensor
 // CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor
-// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({
+// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]])
+// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({
 // CHECK: ^bb0(%[[IVAL_3:.*]]: tensor, %[[IVAL_4:.*]]: tensor):
 // CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor
 // CHECK: stablehlo.return %[[IVAL_5]] : tensor
-// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor
+// CHECK: }) : (tensor, tensor) -> tensor
 // CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor
 // CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor -> !torch.vtensor<[?,?,?,?],f32>
 // CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32>
@@ -194,11 +196,12 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
 // CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list
 // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list
 // CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor
-// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({
+// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]])
+// CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({
 // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor):
 // CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor
 // CHECK: stablehlo.return %[[T10]] : tensor
-// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor
+// CHECK: }) : (tensor, tensor) -> tensor
 // CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor
 // CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor) -> tensor
 // CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor, tensor) -> tensor
diff --git a/test/Conversion/TorchToStablehlo/scatter.mlir b/test/Conversion/TorchToStablehlo/scatter.mlir
index 432fc0c86..fe8ffb9ee 100644
--- a/test/Conversion/TorchToStablehlo/scatter.mlir
+++ b/test/Conversion/TorchToStablehlo/scatter.mlir
@@ -22,10 +22,10 @@
 // CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor, tensor<3xi64>) -> tensor
 // CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor
 // CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor, tensor) -> tensor
-// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) ({
+// CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({
 // CHECK: ^bb0(%arg3: tensor, %[[ARG_4:.*]]: tensor):
 // CHECK: stablehlo.return %[[ARG_4]] : tensor
-// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor, tensor, tensor) -> tensor
+// CHECK: }) : (tensor, tensor, tensor) -> tensor
 // CHECK: %[[VAR_10:.*]] = torch_c.from_builtin_tensor %[[VAR_9]] : tensor -> !torch.vtensor<[?,?],si64>
 // CHECK: return %[[VAR_10]] : !torch.vtensor<[?,?],si64>
 func.func @forward(%arg0: !torch.vtensor<[?,?],si64>, %arg1: !torch.vtensor<[?,?],si64>, %arg2: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {