From 1b8d7e094b39582524e185b808b3f9ee8702f443 Mon Sep 17 00:00:00 2001 From: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com> Date: Wed, 20 Nov 2024 17:43:06 +0100 Subject: [PATCH] [Torch Dialect] Add `torch.aten.mul.int_float` (required to simplify shape calculation of `upsample_nearest2d`) (#3764) As per title. See also [PR](https://github.com/llvm/torch-mlir/pull/3750) for `torch.aten.mul.float_int`. --------- Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com> --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 +++++++++++++++++++ lib/Conversion/TorchToArith/TorchToArith.cpp | 6 ++++- lib/Dialect/Torch/IR/TorchOps.cpp | 13 ++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 10 ++++---- .../build_tools/torch_ods_gen.py | 1 + test/Conversion/TorchToArith/basic.mlir | 14 +++++++++++ test/Dialect/Torch/canonicalize.mlir | 10 ++++++++ 7 files changed, 73 insertions(+), 6 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 28764009a..a3bad0e04 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -15885,6 +15885,31 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [ let hasCanonicalizer = 1; } +def Torch_AtenMulIntFloatOp : Torch_Op<"aten.mul.int_float", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mul.int_float : (int, float) -> (float)`"; + let arguments = (ins + Torch_IntType:$a, + Torch_FloatType:$b + ); + let results = (outs + Torch_FloatType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMulIntFloatOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMulIntFloatOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 458ea3185..4204cc2b1 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -76,6 +76,8 @@ public: Value b = adaptor.getB(); if (llvm::is_one_of::value) b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType()); + if (llvm::is_one_of::value) + a = convertScalarToDtype(rewriter, op.getLoc(), a, b.getType()); rewriter.template replaceOpWithNewOp(op, a, b); return success(); } @@ -487,7 +489,7 @@ public: target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); + AtenMulIntOp, AtenRemainderIntOp, AtenMulIntFloatOp>(); patterns.add>( typeConverter, context); patterns.add>( @@ -498,6 +500,8 @@ public: typeConverter, context); patterns.add>( typeConverter, context); + patterns.add>( + typeConverter, context); target.addIllegalOp(); patterns.add>( typeConverter, context); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index dde9bc130..87d1464e2 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -4219,6 +4219,19 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) { [](double a, double b) -> double { return a * b; }); } +//===----------------------------------------------------------------------===// +// AtenMulIntFloatOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenMulIntFloatOp::fold(FoldAdaptor adaptor) { + if (!adaptor.getA() || !adaptor.getB()) { + return nullptr; + } + return atenBinaryFloatOperatorFoldHelper( + adaptor.getOperands(), + [](double a, double b) -> double { return a * b; }); +} + //===----------------------------------------------------------------------===// // AtenSubOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 1cc02a48f..a8ce5ed20 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -5507,12 +5507,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list, !torch.int -> !torch.float\n" -" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n" +" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %19 = torch.aten.append.t %1, %18 : !torch.list, !torch.int -> !torch.list\n" " %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" " %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list, !torch.int -> !torch.float\n" -" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float \n" +" %22 = torch.aten.mul.int_float %20, %21 : !torch.int, !torch.float -> !torch.float\n" " %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n" " %24 = torch.aten.append.t %1, %23 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.If.yield\n" @@ -11184,7 +11184,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list, !torch.int -> !torch.float\n" -" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n" +" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %19 = torch.prim.ListConstruct %13, %14, %18 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" " torch.prim.If.yield %19 : !torch.list\n" @@ -11264,11 +11264,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list, !torch.int -> !torch.float\n" -" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n" +" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %19 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" " %20 = torch.aten.__getitem__.t %12, %int1 : !torch.list, !torch.int -> !torch.float\n" -" %21 = torch.operator \"aten.mul.int_float\"(%19, %20) : (!torch.int, !torch.float) -> !torch.float \n" +" %21 = torch.aten.mul.int_float %19, %20 : !torch.int, !torch.float -> !torch.float\n" " %22 = torch.aten.Int.float %21 : !torch.float -> !torch.int\n" " %23 = torch.prim.ListConstruct %13, %14, %18, %22 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " torch.prim.If.yield %23 : !torch.list\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 31916f7fe..07029d089 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1118,6 +1118,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): has_folder=True, has_canonicalizer=True, ) + emit("aten::mul.int_float : (int, float) -> (float)", has_folder=True) emit("aten::div.int : (int, int) -> (float)", has_folder=True) emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::log.int : (int) -> (float)") diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir index 86ad4e972..88d08d695 100644 --- a/test/Conversion/TorchToArith/basic.mlir +++ b/test/Conversion/TorchToArith/basic.mlir @@ -236,6 +236,20 @@ func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in return %0 : !torch.int } +// CHECK-LABEL: func.func @torch.aten.mul.int_float( +// CHECK-SAME: %[[LHS:.*]]: !torch.int, +// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float { +// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]] +// CHECK: %[[LHS_F64:.*]] = arith.sitofp %[[LHS_I64]] : i64 to f64 +// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64 +// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL]] +// CHECK: return %[[OUT]] : !torch.float +func.func @torch.aten.mul.int_float(%arg0: !torch.int, %arg1: !torch.float) -> !torch.float { + %0 = torch.aten.mul.int_float %arg0, %arg1 : !torch.int, !torch.float -> !torch.float + return %0 : !torch.float +} + // CHECK-LABEL: func.func @torch.aten.div.float( // CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float { diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index ef478617d..12778f401 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1235,6 +1235,16 @@ func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int { return %ret : !torch.int } +// CHECK-LABEL: func.func @torch.aten.mul.int_float() -> !torch.float { +// CHECK: %[[CST6:.*]] = torch.constant.float 6.000000e+00 +// CHECK: return %[[CST6]] : !torch.float +func.func @torch.aten.mul.int_float() -> !torch.float { + %cst2 = torch.constant.int 2 + %cst3 = torch.constant.float 3.0 + %ret = torch.aten.mul.int_float %cst2, %cst3: !torch.int, !torch.float -> !torch.float + return %ret : !torch.float +} + // CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float { // CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01 // CHECK: return %[[CST30]] : !torch.float