mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] Add `torch.aten.mul.int_float` (required to simplify shape calculation of `upsample_nearest2d`) (#3764)
As per title. See also [PR](https://github.com/llvm/torch-mlir/pull/3750) for `torch.aten.mul.float_int`. --------- Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>pull/909/merge
parent
bdbc64a205
commit
1b8d7e094b
|
@ -15885,6 +15885,31 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenMulIntFloatOp : Torch_Op<"aten.mul.int_float", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::mul.int_float : (int, float) -> (float)`";
|
||||
let arguments = (ins
|
||||
Torch_IntType:$a,
|
||||
Torch_FloatType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_FloatType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenMulIntFloatOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenMulIntFloatOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -76,6 +76,8 @@ public:
|
|||
Value b = adaptor.getB();
|
||||
if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value)
|
||||
b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType());
|
||||
if (llvm::is_one_of<AtenOp, AtenMulIntFloatOp>::value)
|
||||
a = convertScalarToDtype(rewriter, op.getLoc(), a, b.getType());
|
||||
rewriter.template replaceOpWithNewOp<BinOp>(op, a, b);
|
||||
return success();
|
||||
}
|
||||
|
@ -487,7 +489,7 @@ public:
|
|||
target.addIllegalOp<AtenNegIntOp>();
|
||||
patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
|
||||
AtenMulIntOp, AtenRemainderIntOp>();
|
||||
AtenMulIntOp, AtenRemainderIntOp, AtenMulIntFloatOp>();
|
||||
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
|
||||
|
@ -498,6 +500,8 @@ public:
|
|||
typeConverter, context);
|
||||
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
|
||||
typeConverter, context);
|
||||
patterns.add<ConvertAtenBinaryOp<AtenMulIntFloatOp, arith::MulFOp>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenSubFloatOp, AtenMulFloatOp>();
|
||||
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
|
||||
typeConverter, context);
|
||||
|
|
|
@ -4219,6 +4219,19 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) {
|
|||
[](double a, double b) -> double { return a * b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenMulIntFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenMulIntFloatOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getA() || !adaptor.getB()) {
|
||||
return nullptr;
|
||||
}
|
||||
return atenBinaryFloatOperatorFoldHelper(
|
||||
adaptor.getOperands(),
|
||||
[](double a, double b) -> double { return a * b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSubOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -5507,12 +5507,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
|
||||
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
|
||||
" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
|
||||
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
|
||||
" %19 = torch.aten.append.t %1, %18 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
|
||||
" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float \n"
|
||||
" %22 = torch.aten.mul.int_float %20, %21 : !torch.int, !torch.float -> !torch.float\n"
|
||||
" %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n"
|
||||
" %24 = torch.aten.append.t %1, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield\n"
|
||||
|
@ -11184,7 +11184,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
|
||||
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
|
||||
" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
|
||||
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
|
||||
" %19 = torch.prim.ListConstruct %13, %14, %18 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %19 : !torch.list<int>\n"
|
||||
|
@ -11264,11 +11264,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
|
||||
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
|
||||
" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n"
|
||||
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
|
||||
" %19 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %20 = torch.aten.__getitem__.t %12, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
|
||||
" %21 = torch.operator \"aten.mul.int_float\"(%19, %20) : (!torch.int, !torch.float) -> !torch.float \n"
|
||||
" %21 = torch.aten.mul.int_float %19, %20 : !torch.int, !torch.float -> !torch.float\n"
|
||||
" %22 = torch.aten.Int.float %21 : !torch.float -> !torch.int\n"
|
||||
" %23 = torch.prim.ListConstruct %13, %14, %18, %22 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %23 : !torch.list<int>\n"
|
||||
|
|
|
@ -1118,6 +1118,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
has_folder=True,
|
||||
has_canonicalizer=True,
|
||||
)
|
||||
emit("aten::mul.int_float : (int, float) -> (float)", has_folder=True)
|
||||
emit("aten::div.int : (int, int) -> (float)", has_folder=True)
|
||||
emit("aten::neg.int : (int) -> (int)", has_folder=True)
|
||||
emit("aten::log.int : (int) -> (float)")
|
||||
|
|
|
@ -236,6 +236,20 @@ func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in
|
|||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.mul.int_float(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
|
||||
// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]]
|
||||
// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
|
||||
// CHECK: %[[LHS_F64:.*]] = arith.sitofp %[[LHS_I64]] : i64 to f64
|
||||
// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64
|
||||
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL]]
|
||||
// CHECK: return %[[OUT]] : !torch.float
|
||||
func.func @torch.aten.mul.int_float(%arg0: !torch.int, %arg1: !torch.float) -> !torch.float {
|
||||
%0 = torch.aten.mul.int_float %arg0, %arg1 : !torch.int, !torch.float -> !torch.float
|
||||
return %0 : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.div.float(
|
||||
// CHECK-SAME: %[[LHS:.*]]: !torch.float,
|
||||
// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {
|
||||
|
|
|
@ -1235,6 +1235,16 @@ func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int {
|
|||
return %ret : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.mul.int_float() -> !torch.float {
|
||||
// CHECK: %[[CST6:.*]] = torch.constant.float 6.000000e+00
|
||||
// CHECK: return %[[CST6]] : !torch.float
|
||||
func.func @torch.aten.mul.int_float() -> !torch.float {
|
||||
%cst2 = torch.constant.int 2
|
||||
%cst3 = torch.constant.float 3.0
|
||||
%ret = torch.aten.mul.int_float %cst2, %cst3: !torch.int, !torch.float -> !torch.float
|
||||
return %ret : !torch.float
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float {
|
||||
// CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01
|
||||
// CHECK: return %[[CST30]] : !torch.float
|
||||
|
|
Loading…
Reference in New Issue