diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index babbc9eb4..8b048dcc6 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11444,7 +11444,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
-  let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenFloatTensorOp : Torch_Op<"aten.Float.Tensor", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index b5a16fbe3..bfb745f5c 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -150,7 +150,7 @@ static Value getScalarIntValue(Value input, Location loc,
 
   if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
     auto val = valueTensorLiteralOp.getValue()
-                   .cast<DenseElementsAttr>()
+                   .cast<DenseIntElementsAttr>()
                    .getSplatValue<int64_t>();
     return rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(val));
@@ -3646,14 +3646,15 @@ OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) {
 // AtenIntTensorOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
-  // If a scalar number is converted to a 0-d tensor and passed on to
-  // aten.Int.Tensor, fold to the scalar number.
-  if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
-    return numToTensorScalar.getA();
-  if (auto tensorIntOp = getA().getDefiningOp<AtenTensorIntOp>())
-    return tensorIntOp.getT();
-  return nullptr;
+void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add(+[](AtenIntTensorOp op, PatternRewriter &rewriter) {
+    Value scalarInt = getScalarIntValue(op.getA(), op.getLoc(), rewriter);
+    if (!scalarInt)
+      return failure();
+    rewriter.replaceOp(op, scalarInt);
+    return success();
+  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -4099,6 +4100,9 @@ OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) {
     a = IntegerAttr::get(dty, iattr.getInt());
   } else if (auto fattr = dyn_cast<FloatAttr>(a)) {
     a = FloatAttr::get(dty, fattr.getValueAsDouble());
+  } else {
+    // doesn't handle other types, like complex type
+    return {};
   }
 
   auto mlirTensorType =
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 7106b82bb..7ac5fc96b 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -679,7 +679,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True)
     emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True)
     emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
-    emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
+    emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True)
     emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True)
     emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
     emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index a607365f4..9558e897a 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1425,6 +1425,15 @@ func.func @torch.aten.Int.Tensor(%arg0: !torch.int) -> !torch.int {
   return %scalar : !torch.int
 }
 
+// CHECK-LABEL: @torch.aten.Int.Tensor$canonicalize_0d_const() -> !torch.int {
+// CHECK: %[[NUM:.*]] = torch.constant.int 1
+// CHECK: return %[[NUM]] : !torch.int
+func.func @torch.aten.Int.Tensor$canonicalize_0d_const() -> !torch.int {
+  %cst = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
+  %scalar = torch.aten.Int.Tensor %cst : !torch.vtensor<[],si64> -> !torch.int
+  return %scalar : !torch.int
+}
+
 // CHECK-LABEL: func.func @torch.aten.Int.float() -> !torch.int {
 // CHECK: %[[NUM:.*]] = torch.constant.int 1
 // CHECK: return %[[NUM]] : !torch.int