[torch] Add folders for `torch.full`, `torch.ones`, `torch.zeros` and `aten.item` (#2849)

This lets the CumSum op in OPT obtain the constant dimension it requires so that it can be lowered to TMTensor.
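
As a rough sketch of the fold chain this enables (types and operand details are illustrative, mirroring the new canonicalize tests rather than the exact ONNX import output), `aten.ones` now folds to a splat constant and `aten.item` folds that splat to a `torch.constant.int`, so the cumsum dimension becomes statically known:

```mlir
// Illustrative only: approximate IR for the ONNX CumSum import path;
// the function name, shapes, and dtype handling are hypothetical.
func.func @cumsum_dim_fold(%self: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %int1 = torch.constant.int 1
  %int3 = torch.constant.int 3
  %none = torch.constant.none
  %sizes = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
  // With `let hasFolder = 1`, aten.ones folds to a dense splat constant ...
  %axis = torch.aten.ones %sizes, %int3, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
  // ... and aten.item folds that splat to `torch.constant.int 1`, giving
  // aten.cumsum (and the TMTensor lowering) a constant dimension.
  %dim = torch.aten.item %axis : !torch.vtensor<[1],si32> -> !torch.int
  %out = torch.aten.cumsum %self, %dim, %none : !torch.vtensor<[?,?],f32>, !torch.int, !torch.none -> !torch.vtensor<[?,?],f32>
  return %out : !torch.vtensor<[?,?],f32>
}
```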

---------

Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
Xida Ren (Cedar) 2024-02-02 13:46:33 -05:00 committed by GitHub
parent 962d514308
commit 24b8c8672a
7 changed files with 208 additions and 11 deletions


@@ -8416,6 +8416,7 @@ def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [
@@ -8471,6 +8472,7 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
@@ -9858,6 +9860,7 @@ def Torch_AtenItemOp : Torch_Op<"aten.item", [
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenMaskedSelectOp : Torch_Op<"aten.masked_select", [
@@ -11202,6 +11205,7 @@ def Torch_AtenFullOp : Torch_Op<"aten.full", [
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
let hasFolder = 1;
}
def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [


@@ -1089,11 +1089,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, "expected result type to have a dtype");
}
// resultTensorType.print(llvm::outs());
Value resultDType = Torch::getDtypeIntValueForType(
rewriter, loc, resultTensorType.getDtype());
rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(
binder.op, resultType, operand, dim, resultDType);
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(binder.op, resultType,
operand, dim, none);
return success();
});
patterns.onOp(


@@ -6,9 +6,10 @@
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "torch-mlir-torch-dialect"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
@@ -2813,6 +2814,151 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenItemOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
// see if we have a constant tensor
DenseElementsAttr attr;
if (matchPattern(getOperand(), m_Constant(&attr))) {
auto splat = attr.getSplatValue<Attribute>();
if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
return getI64IntegerAttr(getContext(), intAttr.getSInt());
}
if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble());
}
return nullptr;
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenOnesOp, AtenZerosOp, AtenFullOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
SmallVector<int64_t> sizes;
if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: size operand is "
"not a list of constant integers.\n");
return nullptr;
}
Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
if (!resultTensorType || !resultTensorType.hasDtype()) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: result type is not "
"a tensor type or does not have a dtype.\n");
return nullptr;
}
ShapedType shapedty =
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
sizes, resultTensorType.getDtype());
if (!shapedty) {
LLVM_DEBUG(llvm::dbgs()
<< "Failing to fold AtenOnesOp: ShapedType cast failed.\n");
return nullptr;
}
auto elementType = shapedty.getElementType();
if (elementType.isa<IntegerType>()) {
Attribute attribute = IntegerAttr::get(elementType, 1);
return DenseElementsAttr::get(shapedty, attribute);
}
if (elementType.isa<FloatType>()) {
Attribute attribute = FloatAttr::get(elementType, 1.0);
return DenseElementsAttr::get(shapedty, attribute);
}
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: element type is "
"not integer or float.\n");
return nullptr;
}
OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
SmallVector<int64_t> sizes;
if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: size operand is "
"not a list of constant integers.\n");
return nullptr;
}
Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
if (!resultTensorType || !resultTensorType.hasDtype()) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: result type is "
"not a tensor type or does not have a dtype.\n");
return nullptr;
}
ShapedType shapedty =
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
sizes, resultTensorType.getDtype());
if (!shapedty) {
LLVM_DEBUG(llvm::dbgs()
<< "Failing to fold AtenZerosOp: ShapedType cast failed.\n");
return nullptr;
}
auto elementType = shapedty.getElementType();
if (elementType.isa<IntegerType>()) {
Attribute attribute = IntegerAttr::get(elementType, 0);
return DenseElementsAttr::get(shapedty, attribute);
}
if (elementType.isa<FloatType>()) {
Attribute attribute = FloatAttr::get(elementType, 0.0);
return DenseElementsAttr::get(shapedty, attribute);
}
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: element type is "
"not integer or float.\n");
return nullptr;
}
OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
SmallVector<int64_t> sizes;
if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: size operand is "
"not a list of constant integers.\n");
return nullptr;
}
Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
if (!resultTensorType || !resultTensorType.hasDtype()) {
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: result type is not "
"a tensor type or does not have a dtype.\n");
return nullptr;
}
ShapedType shapedty =
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
sizes, resultTensorType.getDtype());
if (!shapedty) {
LLVM_DEBUG(llvm::dbgs()
<< "Failing to fold AtenFullOp: ShapedType cast failed.\n");
return nullptr;
}
auto elementType = shapedty.getElementType();
if (elementType.isa<IntegerType>()) {
int64_t value = 0;
if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) {
Attribute attribute = IntegerAttr::get(elementType, value);
return DenseElementsAttr::get(shapedty, attribute);
}
}
if (elementType.isa<FloatType>()) {
double value = 0.0;
if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
Attribute attribute = FloatAttr::get(elementType, value);
return DenseElementsAttr::get(shapedty, attribute);
}
}
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: element type is "
"not integer or float.\n");
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenCeilFloatOp
//===----------------------------------------------------------------------===//


@@ -26,6 +26,9 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
TORCHDYNAMO_XFAIL_SET = {
#### General TorchDynamo/PyTorch errors
# torch._dynamo.exc.Unsupported: Tensor.item
"CumsumModule_basic",
# TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
# RuntimeError: Failed running call_function aten.convolution_backward(...
# https://github.com/pytorch/pytorch/issues/89629


@@ -564,9 +564,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
emit("aten::is_floating_point : (Tensor) -> (bool)", has_folder=True)
emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
@@ -618,7 +618,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
emit("aten::item : (Tensor) -> (Scalar)")
emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True)
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
@@ -669,7 +669,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
emit("aten::t : (Tensor) -> (Tensor)")
emit("aten::numpy_T : (Tensor) -> (Tensor)")
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")


@@ -4092,7 +4092,13 @@ class CumsumModule(torch.nn.Module):
([-1, -1, -1], torch.float32, True),
])
def forward(self, val):
return torch.ops.aten.cumsum(val, 1)
# The ONNX cumsum op uses a constant 1-d tensor to specify
# the dimension along which to do cumsum. We replicate that
# here to ensure that cumsum correctly triggers the relevant
# folders and provides TMTensor with a constant dimension.
ones = torch.ones([1], dtype=torch.int32)
return torch.ops.aten.cumsum(val, ones.item())
@register_test_case(module_factory=lambda: CumsumModule())
def CumsumModule_basic(module, tu: TestUtils):


@@ -29,6 +29,46 @@ func.func @torch.runtime.assert() {
return
}
// CHECK-LABEL: func.func @torch.aten.ones_item
// CHECK: %[[CONST:.*]] = torch.constant.int 1
// CHECK: return %[[CONST]] : !torch.int
func.func @torch.aten.ones_item() -> !torch.int {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%1 = torch.aten.ones %0, %int3, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
%2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
return %2 : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.zeros_item
// CHECK: %[[CONST:.*]] = torch.constant.int 0
// CHECK: return %[[CONST]] : !torch.int
func.func @torch.aten.zeros_item() -> !torch.int {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%1 = torch.aten.zeros %0, %int3, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
%2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
return %2 : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.full_item
// CHECK: %[[CONST:.*]] = torch.constant.int 1337
// CHECK: return %[[CONST]] : !torch.int
func.func @torch.aten.full_item() -> !torch.int {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 1337
%int5 = torch.constant.int 5
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%1 = torch.aten.full %0, %int3, %int5, %none, %none, %none : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
%2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
return %2 : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.is_floating_point$fold_true
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool