diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 0e9318753..c38d0dbbd 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4223,6 +4223,52 @@ def Torch_AtenRound_Op : Torch_Op<"aten.round_", [
   }];
 }
 
+def Torch_AtenTruncOp : Torch_Op<"aten.trunc", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::trunc : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenTruncOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenTruncOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
+def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::trunc_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenTrunc_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenTrunc_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenSignOp : Torch_Op<"aten.sign", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index e768033ac..a8769def6 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1834,6 +1834,19 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// AtenTruncOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
+  auto resultType = getType().dyn_cast<BaseTensorType>();
+  if (resultType && resultType.hasDtype() &&
+      resultType.getDtype().isa<mlir::IntegerType>()) {
+    return getSelf();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // AtenSignOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 3ac56e933..f4415a480 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6502,6 +6502,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.trunc\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.log\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -10003,6 +10007,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.trunc\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
 "    %int4 = torch.constant.int 4\n"
 "    %int11 = torch.constant.int 11\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 87f93ba9c..49dd53195 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5886,6 +5886,32 @@ class DecomposeAtenCosineSimilarityOp
 };
 } // namespace
 
+namespace {
+// Decompose `trunc(x)` into `sign(x) * floor(abs(x))`.
+class DecomposeAtenTruncOp : public OpRewritePattern<AtenTruncOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenTruncOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+
+    auto resultTy = dyn_cast<ValueTensorType>(op.getType());
+    if (!resultTy || !resultTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "result must have dtype");
+    }
+
+    if (isa<mlir::FloatType>(resultTy.getDtype())) {
+      Value sign = rewriter.create<AtenSgnOp>(loc, resultTy, self);
+      Value abs = rewriter.create<AtenAbsOp>(loc, resultTy, self);
+      Value floor = rewriter.create<AtenFloorOp>(loc, resultTy, abs);
+      rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, resultTy, sign, floor);
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
 // `aten.add.Tensor` op.
@@ -7700,6 +7726,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenTruncOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 701300fef..e1377afce 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -512,6 +512,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenTruncOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ba13b7360..45a4b94a5 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1479,6 +1479,8 @@ STABLEHLO_PASS_SET = {
     "ElementwiseCoshModule_basic",
     "ElementwiseSinhIntModule_basic",
     "ElementwiseSinhModule_basic",
+    "ElementwiseTruncIntModule_basic",
+    "ElementwiseTruncModule_basic",
 }
 
 STABLEHLO_CRASHING_SET = {
@@ -1488,6 +1490,8 @@ STABLEHLO_CRASHING_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "ElementwiseTruncModule_basic",
+    "ElementwiseTruncIntModule_basic",
     "ElementwiseSgnModule_basic",
     "ElementwiseSignIntModule_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
@@ -2344,6 +2348,8 @@ ONNX_XFAIL_SET = {
     "ElementwiseSinhModule_basic",
     "ElementwiseCoshIntModule_basic",
     "ElementwiseCoshModule_basic",
+    "ElementwiseTruncIntModule_basic",
+    "ElementwiseTruncModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
     "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 1bdcfdbe9..06962010f 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -245,6 +245,9 @@ def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], mi
 def aten〇ceil〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇trunc〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇log〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -2227,6 +2230,11 @@ def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇trunc〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0))
 def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float, complex]) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 638ec1dd8..e5b219e55 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -359,6 +359,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
+    emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True)
     emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index b365ac54f..3aa8f10ff 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -2077,6 +2077,50 @@ def ElementwiseCeilModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseTruncModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 6], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.trunc(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseTruncModule())
+def ElementwiseTruncModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5]]))
+
+
+# ==============================================================================
+
+
+class ElementwiseTruncIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.trunc(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseTruncIntModule())
+def ElementwiseTruncIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int32))
+
+
+# ==============================================================================
+
+
 class ElementwiseSignModule(torch.nn.Module):
 
     def __init__(self):
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 2f7d5a11a..4d2a595da 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -2308,6 +2308,14 @@ func.func @torch.aten.floor$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64>
   return %0 : !torch.vtensor<[?,?],si64>
 }
 
+// CHECK-LABEL:   func.func @torch.aten.trunc$canonicalize
+// CHECK-SAME:        %[[ARG:.*]]: !torch.vtensor<[?,?],si64>
+// CHECK-NEXT:        return %[[ARG]] : !torch.vtensor<[?,?],si64>
+func.func @torch.aten.trunc$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+  %0 = torch.aten.trunc %arg0 : !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
+  return %0 : !torch.vtensor<[?,?],si64>
+}
+
 // CHECK-LABEL:   func.func @torch.aten.numel$canonicalize
 // CHECK-SAME:        %[[ARG:.*]]: !torch.vtensor<[3,4],f32>
 // CHECK-NEXT:        %int12 = torch.constant.int 12
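
Reviewer note, not part of the patch: the two rewrites above rest on identities that are easy to sanity-check in eager PyTorch. `DecomposeAtenTruncOp` rewrites `trunc(x)` as `sgn(x) * floor(abs(x))` for float dtypes, and `AtenTruncOp::fold` assumes `trunc` is the identity on integer tensors. A minimal sketch of both checks, reusing the input values from the e2e tests in this diff:

import torch

# Float path: DecomposeAtenTruncOp emits sgn(x) * floor(abs(x)) in place of
# trunc(x). Same inputs as ElementwiseTruncModule_basic, so the inf/nan and
# negative-fraction edge cases are covered.
x = torch.tensor([[-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5]])
decomposed = torch.sgn(x) * torch.floor(torch.abs(x))
torch.testing.assert_close(decomposed, torch.trunc(x), equal_nan=True)

# Integer path: AtenTruncOp::fold returns its operand unchanged for integer
# dtypes, mirroring eager PyTorch, where trunc on an integer tensor is a no-op
# (the behavior ElementwiseTruncIntModule_basic exercises end to end).
y = torch.randint(-100, 100, (3, 4), dtype=torch.int32)
assert torch.equal(torch.trunc(y), y)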