From 24b8c8672ae4366025aa8cb3155dedb58c8e3450 Mon Sep 17 00:00:00 2001
From: "Xida Ren (Cedar)"
Date: Fri, 2 Feb 2024 13:46:33 -0500
Subject: [PATCH] [torch] Add folders for `torch.fill`, `torch.ones`,
 `torch.zeros` and `aten.getItem` (#2849)

So that the CumSum Op in OPT can get the constant that it requires to be
lowered to TMTensor.

---------

Co-authored-by: Rob Suderman
Co-authored-by: Xida Ren
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |   4 +
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    |   8 +-
 lib/Dialect/Torch/IR/TorchOps.cpp             | 148 +++++++++++++++++-
 projects/pt1/e2e_testing/xfail_sets.py        |   3 +
 .../build_tools/torch_ods_gen.py              |   8 +-
 .../torch_mlir_e2e_test/test_suite/basic.py   |   8 +-
 test/Dialect/Torch/canonicalize.mlir          |  40 +++++
 7 files changed, 208 insertions(+), 11 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 0ae45798d..a0ec9663b 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8416,6 +8416,7 @@ def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
       printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [
@@ -8471,6 +8472,7 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
       printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
@@ -9858,6 +9860,7 @@ def Torch_AtenItemOp : Torch_Op<"aten.item", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenMaskedSelectOp : Torch_Op<"aten.masked_select", [
@@ -11202,6 +11205,7 @@ def Torch_AtenFullOp : Torch_Op<"aten.full", [
       printDefaultTorchOp(printer, *this, 6, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index e9221ed13..1161b981c 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1089,11 +1089,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
               binder.op, "expected result type to have a dtype");
         }
         // resultTensorType.print(llvm::outs());
-        Value resultDType = Torch::getDtypeIntValueForType(
-            rewriter, loc, resultTensorType.getDtype());
-
-        rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(
-            binder.op, resultType, operand, dim, resultDType);
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+        rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(binder.op, resultType,
+                                                         operand, dim, none);
         return success();
       });
   patterns.onOp(
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 5877a3549..98de4f85b 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -6,9 +6,10 @@
 // Also available under a BSD-style license. See LICENSE.
 //
 //===----------------------------------------------------------------------===//
-
+#define DEBUG_TYPE "torch-mlir-torch-dialect"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/Support/Debug.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Builders.h"
@@ -2813,6 +2814,151 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenItemOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
+  // see if we have a constant tensor
+  DenseElementsAttr attr;
+  if (matchPattern(getOperand(), m_Constant(&attr))) {
+    auto splat = attr.getSplatValue<Attribute>();
+    if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
+      return getI64IntegerAttr(getContext(), intAttr.getSInt());
+    }
+    if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
+      return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble());
+    }
+    return nullptr;
+  }
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// AtenOnesOp, AtenZerosOp, AtenFullOp
+//===----------------------------------------------------------------------===//
+OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
+  SmallVector<int64_t> sizes;
+  if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: size operand is "
+                               "not a list of constant integers.\n");
+    return nullptr;
+  }
+
+  Type resultType = getResult().getType();
+  BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
+  if (!resultTensorType || !resultTensorType.hasDtype()) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: result type is not "
+                               "a tensor type or does not have a dtype.\n");
+    return nullptr;
+  }
+
+  ShapedType shapedty =
+      mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
+          sizes, resultTensorType.getDtype());
+  if (!shapedty) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Failing to fold AtenOnesOp: ShapedType cast failed.\n");
+    return nullptr;
+  }
+  auto elementType = shapedty.getElementType();
+  if (elementType.isa<IntegerType>()) {
+    Attribute attribute = IntegerAttr::get(elementType, 1);
+    return DenseElementsAttr::get(shapedty, attribute);
+  }
+  if (elementType.isa<mlir::FloatType>()) {
+    Attribute attribute = FloatAttr::get(elementType, 1.0);
+    return DenseElementsAttr::get(shapedty, attribute);
+  }
+  LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: element type is "
+                             "not integer or float.\n");
+  return nullptr;
+}
+
+OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
+  SmallVector<int64_t> sizes;
+  if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: size operand is "
+                               "not a list of constant integers.\n");
+    return nullptr;
+  }
+
+  Type resultType = getResult().getType();
+  BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
+  if (!resultTensorType || !resultTensorType.hasDtype()) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: result type is "
+                               "not a tensor type or does not have a dtype.\n");
+    return nullptr;
+  }
+
+  ShapedType shapedty =
+      mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
+          sizes, resultTensorType.getDtype());
+  if (!shapedty) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Failing to fold AtenZerosOp: ShapedType cast failed.\n");
+    return nullptr;
+  }
+
+  auto elementType = shapedty.getElementType();
+  if (elementType.isa<IntegerType>()) {
+    Attribute attribute = IntegerAttr::get(elementType, 0);
+    return DenseElementsAttr::get(shapedty, attribute);
+  }
+  if (elementType.isa<mlir::FloatType>()) {
+    Attribute attribute = FloatAttr::get(elementType, 0.0);
+    return DenseElementsAttr::get(shapedty, attribute);
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: element type is "
+                             "not integer or float.\n");
+  return nullptr;
+}
+
+OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
+  SmallVector<int64_t> sizes;
+  if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: size operand is "
+                               "not a list of constant integers.\n");
+    return nullptr;
+  }
+
+  Type resultType = getResult().getType();
+  BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
+  if (!resultTensorType || !resultTensorType.hasDtype()) {
+    LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: result type is not "
+                               "a tensor type or does not have a dtype.\n");
+    return nullptr;
+  }
+
+  ShapedType shapedty =
+      mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
+          sizes, resultTensorType.getDtype());
+  if (!shapedty) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Failing to fold AtenFullOp: ShapedType cast failed.\n");
+    return nullptr;
+  }
+  auto elementType = shapedty.getElementType();
+  if (elementType.isa<IntegerType>()) {
+    int64_t value = 0;
+    if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) {
+      Attribute attribute = IntegerAttr::get(elementType, value);
+      return DenseElementsAttr::get(shapedty, attribute);
+    }
+  }
+  if (elementType.isa<mlir::FloatType>()) {
+    double value = 0.0;
+    if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
+      Attribute attribute = FloatAttr::get(elementType, value);
+      return DenseElementsAttr::get(shapedty, attribute);
+    }
+  }
+  LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: element type is "
+                             "not integer or float.\n");
+  return nullptr;
+}
 //===----------------------------------------------------------------------===//
 // AtenCeilFloatOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index d7dba54a0..2ee5d279a 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -26,6 +26,9 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
 
 TORCHDYNAMO_XFAIL_SET = {
     #### General TorchDynamo/PyTorch errors
+    # torch._dynamo.exc.Unsupported: Tensor.item
+    "CumsumModule_basic",
+
     # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
     # RuntimeError: Failed running call_function aten.convolution_backward(...
     # https://github.com/pytorch/pytorch/issues/89629
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 43635bf2f..41a297ba6 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -564,9 +564,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
     emit("aten::Bool.Tensor : (Tensor) -> (bool)")
     emit("aten::is_floating_point : (Tensor) -> (bool)", has_folder=True)
-    emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
+    emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
     emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
-    emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
+    emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
     emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
@@ -618,7 +618,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
     emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
     emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
-    emit("aten::item : (Tensor) -> (Scalar)")
+    emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
     emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
     emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True)
     emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
@@ -669,7 +669,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
     emit("aten::t : (Tensor) -> (Tensor)")
     emit("aten::numpy_T : (Tensor) -> (Tensor)")
-    emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
+    emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
     emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
     emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
     emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 91c311213..c73d706f2 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -4092,7 +4092,13 @@ class CumsumModule(torch.nn.Module):
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, val):
-        return torch.ops.aten.cumsum(val, 1)
+        # the onnx cumsum op uses a constant 1d tensor
+        # to specify the dimension along which to do cumsum.
+        # We replicate that here to ensure that cumsum correctly
+        # triggers the relevant folders and provides TMTensor
+        # with a constant dimension.
+        ones = torch.ones([1], dtype=torch.int32)
+        return torch.ops.aten.cumsum(val, ones.item())
 
 @register_test_case(module_factory=lambda: CumsumModule())
 def CumsumModule_basic(module, tu: TestUtils):
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 3cf82d9ed..cb2ec2d14 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -29,6 +29,46 @@ func.func @torch.runtime.assert() {
   return
 }
 
+// CHECK-LABEL:   func.func @torch.aten.ones_item
+// CHECK:           %[[CONST:.*]] = torch.constant.int 1
+// CHECK:           return %[[CONST]] : !torch.int
+func.func @torch.aten.ones_item() -> !torch.int {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.aten.ones %0, %int3, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
+  %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
+  return %2 : !torch.int
+}
+
+// CHECK-LABEL:   func.func @torch.aten.zeros_item
+// CHECK:           %[[CONST:.*]] = torch.constant.int 0
+// CHECK:           return %[[CONST]] : !torch.int
+func.func @torch.aten.zeros_item() -> !torch.int {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.aten.zeros %0, %int3, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
+  %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
+  return %2 : !torch.int
+}
+
+// CHECK-LABEL:   func.func @torch.aten.full_item
+// CHECK:           %[[CONST:.*]] = torch.constant.int 1337
+// CHECK:           return %[[CONST]] : !torch.int
+func.func @torch.aten.full_item() -> !torch.int {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 1337
+  %int5 = torch.constant.int 5
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.aten.full %0, %int3, %int5, %none, %none, %none : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
+  %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
+  return %2 : !torch.int
+}
+
 // CHECK-LABEL:   func.func @torch.aten.is_floating_point$fold_true
 // CHECK:           %[[TRUE:.*]] = torch.constant.bool true
 // CHECK:           return %[[TRUE]] : !torch.bool