mirror of https://github.com/llvm/torch-mlir
parent
f1b8d5e581
commit
6ab990e1e8
|
@ -9222,6 +9222,7 @@ def Torch_AtenIntFloatOp : Torch_Op<"aten.Int.float", [
|
|||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenIntScalarOp : Torch_Op<"aten.Int.Scalar", [
|
||||
|
|
|
@ -232,6 +232,23 @@ public:
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertTorchConstantIntOp
|
||||
: public OpConversionPattern<Torch::ConstantIntOp> {
|
||||
public:
|
||||
using OpConversionPattern<Torch::ConstantIntOp>::OpConversionPattern;
|
||||
using OpAdaptor = Torch::ConstantIntOp::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(Torch::ConstantIntOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// note: arith.constant only accept singless integer, so convert singed to
|
||||
// singless
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, rewriter.getIntegerAttr(rewriter.getI64Type(),
|
||||
op.getValueAttr().getValue()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
@ -381,8 +398,8 @@ public:
|
|||
patterns.add<ConvertTorchConstantOp<Torch::ConstantFloatOp>>(typeConverter,
|
||||
context);
|
||||
target.addIllegalOp<Torch::ConstantIntOp>();
|
||||
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
|
||||
context);
|
||||
patterns.add<ConvertTorchConstantIntOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
|
||||
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
|
||||
typeConverter, context);
|
||||
|
|
|
@ -1330,6 +1330,20 @@ OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenIntFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
|
||||
// Constant fold float -> int conversion.
|
||||
if (auto floatAttr = adaptor.getA().dyn_cast_or_null<FloatAttr>()) {
|
||||
return IntegerAttr::get(
|
||||
mlir::IntegerType::get(getContext(), 64, IntegerType::Signed),
|
||||
static_cast<int64_t>(floatAttr.getValue().convertToDouble()));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenIntScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -582,7 +582,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
# Type conversion ops.
|
||||
emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True)
|
||||
emit("aten::Float.str : (str) -> (float)")
|
||||
emit("aten::Int.float : (float) -> (int)")
|
||||
emit("aten::Int.float : (float) -> (int)", has_folder=True)
|
||||
emit("aten::Int.Scalar : (Scalar) -> (int)", has_folder=True)
|
||||
emit("aten::Int.bool : (bool) -> (int)", has_folder=True)
|
||||
|
||||
|
|
|
@ -2666,6 +2666,25 @@ def LenStrModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class IntFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.value = 1.0
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
])
|
||||
def forward(self):
|
||||
return torch.ops.aten.Int(self.value)
|
||||
|
||||
@register_test_case(module_factory=lambda: IntFloatModule())
|
||||
def IntFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ScalarImplicitFloatModule(torch.nn.Module):
|
||||
|
||||
|
|
|
@ -1271,6 +1271,15 @@ func.func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
|
|||
return %scalar : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.Int.float() -> !torch.int {
|
||||
// CHECK: %[[NUM:.*]] = torch.constant.int 1
|
||||
// CHECK: return %[[NUM]] : !torch.int
|
||||
func.func @torch.aten.Int.float() -> !torch.int {
|
||||
%float1 = torch.constant.float 1.0
|
||||
%int1 = torch.aten.Int.float %float1 : !torch.float -> !torch.int
|
||||
return %int1 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.Float.Tensor(
|
||||
// CHECK-SAME: %[[NUM:.*]]: !torch.float) -> !torch.float {
|
||||
// CHECK: %[[T:.*]] = torch.prim.NumToTensor.Scalar %[[NUM]] : !torch.float -> !torch.vtensor<[],f64>
|
||||
|
|
Loading…
Reference in New Issue