[Torch Dialect] Add `torch.aten.mul.int_float` (required to simplify shape calculation of `upsample_nearest2d`) (#3764)

As per title. See also
[PR](https://github.com/llvm/torch-mlir/pull/3750) for
`torch.aten.mul.float_int`.

---------

Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
pull/909/merge
Giacomo Serafini 2024-11-20 17:43:06 +01:00 committed by GitHub
parent bdbc64a205
commit 1b8d7e094b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 73 additions and 6 deletions

View File

@ -15885,6 +15885,31 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
} }
def Torch_AtenMulIntFloatOp : Torch_Op<"aten.mul.int_float", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::mul.int_float : (int, float) -> (float)`";
let arguments = (ins
Torch_IntType:$a,
Torch_FloatType:$b
);
let results = (outs
Torch_FloatType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMulIntFloatOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenMulIntFloatOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [ def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -76,6 +76,8 @@ public:
Value b = adaptor.getB(); Value b = adaptor.getB();
if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value) if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value)
b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType()); b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType());
if (llvm::is_one_of<AtenOp, AtenMulIntFloatOp>::value)
a = convertScalarToDtype(rewriter, op.getLoc(), a, b.getType());
rewriter.template replaceOpWithNewOp<BinOp>(op, a, b); rewriter.template replaceOpWithNewOp<BinOp>(op, a, b);
return success(); return success();
} }
@ -487,7 +489,7 @@ public:
target.addIllegalOp<AtenNegIntOp>(); target.addIllegalOp<AtenNegIntOp>();
patterns.add<ConvertAtenNegIntOp>(typeConverter, context); patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp, target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
AtenMulIntOp, AtenRemainderIntOp>(); AtenMulIntOp, AtenRemainderIntOp, AtenMulIntFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>( patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context); typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>( patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
@ -498,6 +500,8 @@ public:
typeConverter, context); typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>( patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
typeConverter, context); typeConverter, context);
patterns.add<ConvertAtenBinaryOp<AtenMulIntFloatOp, arith::MulFOp>>(
typeConverter, context);
target.addIllegalOp<AtenSubFloatOp, AtenMulFloatOp>(); target.addIllegalOp<AtenSubFloatOp, AtenMulFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>( patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
typeConverter, context); typeConverter, context);

View File

@ -4219,6 +4219,19 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) {
[](double a, double b) -> double { return a * b; }); [](double a, double b) -> double { return a * b; });
} }
//===----------------------------------------------------------------------===//
// AtenMulIntFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenMulIntFloatOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getA() || !adaptor.getB()) {
return nullptr;
}
return atenBinaryFloatOperatorFoldHelper(
adaptor.getOperands(),
[](double a, double b) -> double { return a * b; });
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenSubOp // AtenSubOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -5507,12 +5507,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list<float>, !torch.int -> !torch.float\n" " %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n" " %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
" %19 = torch.aten.append.t %1, %18 : !torch.list<int>, !torch.int -> !torch.list<int>\n" " %19 = torch.aten.append.t %1, %18 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n" " %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list<float>, !torch.int -> !torch.float\n" " %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float \n" " %22 = torch.aten.mul.int_float %20, %21 : !torch.int, !torch.float -> !torch.float\n"
" %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n" " %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n"
" %24 = torch.aten.append.t %1, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n" " %24 = torch.aten.append.t %1, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n" " torch.prim.If.yield\n"
@ -11184,7 +11184,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n" " %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n" " %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n" " %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
" %19 = torch.prim.ListConstruct %13, %14, %18 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n" " %19 = torch.prim.ListConstruct %13, %14, %18 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %19 : !torch.list<int>\n" " torch.prim.If.yield %19 : !torch.list<int>\n"
@ -11264,11 +11264,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n" " %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n" " %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n" " %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
" %19 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n" " %19 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %20 = torch.aten.__getitem__.t %12, %int1 : !torch.list<float>, !torch.int -> !torch.float\n" " %20 = torch.aten.__getitem__.t %12, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
" %21 = torch.operator \"aten.mul.int_float\"(%19, %20) : (!torch.int, !torch.float) -> !torch.float \n" " %21 = torch.aten.mul.int_float %19, %20 : !torch.int, !torch.float -> !torch.float\n"
" %22 = torch.aten.Int.float %21 : !torch.float -> !torch.int\n" " %22 = torch.aten.Int.float %21 : !torch.float -> !torch.int\n"
" %23 = torch.prim.ListConstruct %13, %14, %18, %22 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n" " %23 = torch.prim.ListConstruct %13, %14, %18, %22 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" torch.prim.If.yield %23 : !torch.list<int>\n" " torch.prim.If.yield %23 : !torch.list<int>\n"

View File

@ -1118,6 +1118,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
has_folder=True, has_folder=True,
has_canonicalizer=True, has_canonicalizer=True,
) )
emit("aten::mul.int_float : (int, float) -> (float)", has_folder=True)
emit("aten::div.int : (int, int) -> (float)", has_folder=True) emit("aten::div.int : (int, int) -> (float)", has_folder=True)
emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::neg.int : (int) -> (int)", has_folder=True)
emit("aten::log.int : (int) -> (float)") emit("aten::log.int : (int) -> (float)")

View File

@ -236,6 +236,20 @@ func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in
return %0 : !torch.int return %0 : !torch.int
} }
// CHECK-LABEL: func.func @torch.aten.mul.int_float(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
// CHECK: %[[LHS_F64:.*]] = arith.sitofp %[[LHS_I64]] : i64 to f64
// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL]]
// CHECK: return %[[OUT]] : !torch.float
func.func @torch.aten.mul.int_float(%arg0: !torch.int, %arg1: !torch.float) -> !torch.float {
%0 = torch.aten.mul.int_float %arg0, %arg1 : !torch.int, !torch.float -> !torch.float
return %0 : !torch.float
}
// CHECK-LABEL: func.func @torch.aten.div.float( // CHECK-LABEL: func.func @torch.aten.div.float(
// CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[LHS:.*]]: !torch.float,
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float { // CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {

View File

@ -1235,6 +1235,16 @@ func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int {
return %ret : !torch.int return %ret : !torch.int
} }
// CHECK-LABEL: func.func @torch.aten.mul.int_float() -> !torch.float {
// CHECK: %[[CST6:.*]] = torch.constant.float 6.000000e+00
// CHECK: return %[[CST6]] : !torch.float
func.func @torch.aten.mul.int_float() -> !torch.float {
%cst2 = torch.constant.int 2
%cst3 = torch.constant.float 3.0
%ret = torch.aten.mul.int_float %cst2, %cst3: !torch.int, !torch.float -> !torch.float
return %ret : !torch.float
}
// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float { // CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01 // CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
// CHECK: return %[[CST30]] : !torch.float // CHECK: return %[[CST30]] : !torch.float