mirror of https://github.com/llvm/torch-mlir
[torch] Add folders for `torch.fill`, `torch.ones`, `torch.zeros` and `aten.getItem` (#2849)
So that the CumSum Op in OPT can get the constant that it requires to be lowered to TMTensor --------- Co-authored-by: Rob Suderman <rob.suderman@gmail.com> Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>pull/2856/head
parent
962d514308
commit
24b8c8672a
|
@ -8416,6 +8416,7 @@ def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
|
|||
printDefaultTorchOp(printer, *this, 5, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [
|
||||
|
@ -8471,6 +8472,7 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
|
|||
printDefaultTorchOp(printer, *this, 5, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
|
||||
|
@ -9858,6 +9860,7 @@ def Torch_AtenItemOp : Torch_Op<"aten.item", [
|
|||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenMaskedSelectOp : Torch_Op<"aten.masked_select", [
|
||||
|
@ -11202,6 +11205,7 @@ def Torch_AtenFullOp : Torch_Op<"aten.full", [
|
|||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [
|
||||
|
|
|
@ -1089,11 +1089,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.op, "expected result type to have a dtype");
|
||||
}
|
||||
// resultTensorType.print(llvm::outs());
|
||||
Value resultDType = Torch::getDtypeIntValueForType(
|
||||
rewriter, loc, resultTensorType.getDtype());
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(
|
||||
binder.op, resultType, operand, dim, resultDType);
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(binder.op, resultType,
|
||||
operand, dim, none);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
|
|
|
@ -6,9 +6,10 @@
|
|||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define DEBUG_TYPE "torch-mlir-torch-dialect"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
@ -2813,6 +2814,151 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenItemOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
|
||||
// see if we have a constant tensor
|
||||
DenseElementsAttr attr;
|
||||
if (matchPattern(getOperand(), m_Constant(&attr))) {
|
||||
auto splat = attr.getSplatValue<Attribute>();
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
|
||||
return getI64IntegerAttr(getContext(), intAttr.getSInt());
|
||||
}
|
||||
if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
|
||||
return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenOnesOp, AtenZerosOp, AtenFullOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
|
||||
SmallVector<int64_t> sizes;
|
||||
if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: size operand is "
|
||||
"not a list of constant integers.\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type resultType = getResult().getType();
|
||||
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
|
||||
if (!resultTensorType || !resultTensorType.hasDtype()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: result type is not "
|
||||
"a tensor type or does not have a dtype.\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ShapedType shapedty =
|
||||
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
|
||||
sizes, resultTensorType.getDtype());
|
||||
if (!shapedty) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Failing to fold AtenOnesOp: ShapedType cast failed.\n");
|
||||
return nullptr;
|
||||
}
|
||||
auto elementType = shapedty.getElementType();
|
||||
if (elementType.isa<IntegerType>()) {
|
||||
Attribute attribute = IntegerAttr::get(elementType, 1);
|
||||
return DenseElementsAttr::get(shapedty, attribute);
|
||||
}
|
||||
if (elementType.isa<FloatType>()) {
|
||||
Attribute attribute = FloatAttr::get(elementType, 1.0);
|
||||
return DenseElementsAttr::get(shapedty, attribute);
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: element type is "
|
||||
"not integer or float.\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
|
||||
SmallVector<int64_t> sizes;
|
||||
if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: size operand is "
|
||||
"not a list of constant integers.\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type resultType = getResult().getType();
|
||||
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
|
||||
if (!resultTensorType || !resultTensorType.hasDtype()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: result type is "
|
||||
"not a tensor type or does not have a dtype.\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ShapedType shapedty =
|
||||
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
|
||||
sizes, resultTensorType.getDtype());
|
||||
if (!shapedty) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Failing to fold AtenZerosOp: ShapedType cast failed.\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto elementType = shapedty.getElementType();
|
||||
if (elementType.isa<IntegerType>()) {
|
||||
Attribute attribute = IntegerAttr::get(elementType, 0);
|
||||
return DenseElementsAttr::get(shapedty, attribute);
|
||||
}
|
||||
if (elementType.isa<FloatType>()) {
|
||||
Attribute attribute = FloatAttr::get(elementType, 0.0);
|
||||
return DenseElementsAttr::get(shapedty, attribute);
|
||||
}
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: element type is "
|
||||
"not integer or float.\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
|
||||
SmallVector<int64_t> sizes;
|
||||
if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: size operand is "
|
||||
"not a list of constant integers.\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type resultType = getResult().getType();
|
||||
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
|
||||
if (!resultTensorType || !resultTensorType.hasDtype()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: result type is not "
|
||||
"a tensor type or does not have a dtype.\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ShapedType shapedty =
|
||||
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
|
||||
sizes, resultTensorType.getDtype());
|
||||
if (!shapedty) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Failing to fold AtenFullOp: ShapedType cast failed.\n");
|
||||
return nullptr;
|
||||
}
|
||||
auto elementType = shapedty.getElementType();
|
||||
if (elementType.isa<IntegerType>()) {
|
||||
int64_t value = 0;
|
||||
if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) {
|
||||
Attribute attribute = IntegerAttr::get(elementType, value);
|
||||
return DenseElementsAttr::get(shapedty, attribute);
|
||||
}
|
||||
}
|
||||
if (elementType.isa<FloatType>()) {
|
||||
double value = 0.0;
|
||||
if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
|
||||
Attribute attribute = FloatAttr::get(elementType, value);
|
||||
return DenseElementsAttr::get(shapedty, attribute);
|
||||
}
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: element type is "
|
||||
"not integer or float.\n");
|
||||
return nullptr;
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenCeilFloatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -26,6 +26,9 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
|||
TORCHDYNAMO_XFAIL_SET = {
|
||||
#### General TorchDynamo/PyTorch errors
|
||||
|
||||
# torch._dynamo.exc.Unsupported: Tensor.item
|
||||
"CumsumModule_basic",
|
||||
|
||||
# TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
|
||||
# RuntimeError: Failed running call_function aten.convolution_backward(...
|
||||
# https://github.com/pytorch/pytorch/issues/89629
|
||||
|
|
|
@ -564,9 +564,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
|
||||
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
|
||||
emit("aten::is_floating_point : (Tensor) -> (bool)", has_folder=True)
|
||||
emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
|
||||
emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
|
||||
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
|
@ -618,7 +618,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
|
||||
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
|
||||
emit("aten::item : (Tensor) -> (Scalar)")
|
||||
emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
|
||||
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True)
|
||||
emit("aten::repeat : (Tensor, int[]) -> (Tensor)")
|
||||
|
@ -669,7 +669,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
|
||||
emit("aten::t : (Tensor) -> (Tensor)")
|
||||
emit("aten::numpy_T : (Tensor) -> (Tensor)")
|
||||
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)", has_folder=True)
|
||||
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
|
||||
|
|
|
@ -4092,7 +4092,13 @@ class CumsumModule(torch.nn.Module):
|
|||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, val):
|
||||
return torch.ops.aten.cumsum(val, 1)
|
||||
# the onnx cumsum op uses a constant 1d tensor
|
||||
# to specify the dimension along which to do cumsum
|
||||
# we replicate that here to ensure that cumsum correctly
|
||||
# trigger the relevant folders and provides TMTensor
|
||||
# with a constant dimension
|
||||
ones = torch.ones([1], dtype=torch.int32)
|
||||
return torch.ops.aten.cumsum(val, ones.item())
|
||||
|
||||
@register_test_case(module_factory=lambda: CumsumModule())
|
||||
def CumsumModule_basic(module, tu: TestUtils):
|
||||
|
|
|
@ -29,6 +29,46 @@ func.func @torch.runtime.assert() {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.ones_item
|
||||
// CHECK: %[[CONST:.*]] = torch.constant.int 1
|
||||
// CHECK: return %[[CONST]] : !torch.int
|
||||
func.func @torch.aten.ones_item() -> !torch.int {
|
||||
%int1 = torch.constant.int 1
|
||||
%int3 = torch.constant.int 3
|
||||
%none = torch.constant.none
|
||||
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.ones %0, %int3, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
|
||||
%2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
|
||||
return %2 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.zeros_item
|
||||
// CHECK: %[[CONST:.*]] = torch.constant.int 0
|
||||
// CHECK: return %[[CONST]] : !torch.int
|
||||
func.func @torch.aten.zeros_item() -> !torch.int {
|
||||
%int1 = torch.constant.int 1
|
||||
%int3 = torch.constant.int 3
|
||||
%none = torch.constant.none
|
||||
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.zeros %0, %int3, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
|
||||
%2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
|
||||
return %2 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.full_item
|
||||
// CHECK: %[[CONST:.*]] = torch.constant.int 1337
|
||||
// CHECK: return %[[CONST]] : !torch.int
|
||||
func.func @torch.aten.full_item() -> !torch.int {
|
||||
%int1 = torch.constant.int 1
|
||||
%int3 = torch.constant.int 1337
|
||||
%int5 = torch.constant.int 5
|
||||
%none = torch.constant.none
|
||||
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
%1 = torch.aten.full %0, %int3, %int5, %none, %none, %none : !torch.list<int>, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32>
|
||||
%2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int
|
||||
return %2 : !torch.int
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.is_floating_point$fold_true
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: return %[[TRUE]] : !torch.bool
|
||||
|
|
Loading…
Reference in New Issue