[Torch Dialect] add folder for aten.Int.float (#1863)

Yuanqiang Liu 2023-02-11 05:59:03 +08:00 committed by GitHub
parent f1b8d5e581
commit 6ab990e1e8
6 changed files with 63 additions and 3 deletions


@@ -9222,6 +9222,7 @@ def Torch_AtenIntFloatOp : Torch_Op<"aten.Int.float", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [


@@ -232,6 +232,23 @@ public:
return success();
}
};
class ConvertTorchConstantIntOp
: public OpConversionPattern<Torch::ConstantIntOp> {
public:
using OpConversionPattern<Torch::ConstantIntOp>::OpConversionPattern;
using OpAdaptor = Torch::ConstantIntOp::Adaptor;
LogicalResult
matchAndRewrite(Torch::ConstantIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Note: arith.constant only accepts signless integers, so rebuild the signed
// value attribute as a signless i64 attribute.
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, rewriter.getIntegerAttr(rewriter.getI64Type(),
op.getValueAttr().getValue()));
return success();
}
};
} // namespace
namespace {
@@ -381,8 +398,8 @@ public:
patterns.add<ConvertTorchConstantOp<Torch::ConstantFloatOp>>(typeConverter,
context);
target.addIllegalOp<Torch::ConstantIntOp>();
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
context);
patterns.add<ConvertTorchConstantIntOp>(typeConverter, context);
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
typeConverter, context);
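
For orientation, a minimal sketch (not part of the diff) of what the new ConvertTorchConstantIntOp pattern produces, assuming a lone torch.constant.int is run through the Torch-to-Arith conversion:

// Input: a torch.constant.int, whose value is conceptually signed.
%int3 = torch.constant.int 3
// Output: the same value rebuilt as a signless i64 arith.constant,
// since arith.constant does not accept signed integer types.
%c3_i64 = arith.constant 3 : i64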


@@ -1330,6 +1330,20 @@ OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenIntFloatOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
// Constant fold float -> int conversion.
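// static_cast<int64_t> truncates toward zero, matching Python int() semantics
// for aten::Int.float.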
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
return IntegerAttr::get(
mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenIntScalarOp
//===----------------------------------------------------------------------===//
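
Stepping outside the diff for a moment: the AtenIntFloatOp folder above truncates toward zero (static_cast<int64_t> semantics), so a non-integral input behaves as sketched below, assuming the standard canonicalization pass is run. This is illustrative only; the canonicalize test at the end of this commit exercises the same path.

// Before canonicalization (hypothetical input):
%float = torch.constant.float 2.5
%int = torch.aten.Int.float %float : !torch.float -> !torch.int
// After folding, the op is gone and only the truncated constant remains:
%int2 = torch.constant.int 2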


@@ -582,7 +582,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
# Type conversion ops.
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
emit("aten::Float.str : (str) -> (float)")
emit("aten::Int.float : (float) -> (int)")
emit("aten::Int.float : (float) -> (int)", has_folder=True)
emit("aten::Int.Scalar : (Scalar) -> (int)", has_folder=True)
emit("aten::Int.bool : (bool) -> (int)", has_folder=True)


@@ -2666,6 +2666,25 @@ def LenStrModule_basic(module, tu: TestUtils):
# ==============================================================================
class IntFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.value = 1.0

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        return torch.ops.aten.Int(self.value)


@register_test_case(module_factory=lambda: IntFloatModule())
def IntFloatModule_basic(module, tu: TestUtils):
    module.forward()
# ==============================================================================
class ScalarImplicitFloatModule(torch.nn.Module):


@@ -1271,6 +1271,15 @@ func.func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
return %scalar : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.Int.float() -> !torch.int {
// CHECK: %[[NUM:.*]] = torch.constant.int 1
// CHECK: return %[[NUM]] : !torch.int
func.func @torch.aten.Int.float() -> !torch.int {
%float1 = torch.constant.float 1.0
%int1 = torch.aten.Int.float %float1 : !torch.float -> !torch.int
return %int1 : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.Float.Tensor(
// CHECK-SAME: %[[NUM:.*]]: !torch.float) -> !torch.float {
// CHECK: %[[T:.*]] = torch.prim.NumToTensor.Scalar %[[NUM]] : !torch.float -> !torch.vtensor<[],f64>