From 9535be790363ecb97df0f64c85ac279aeca26011 Mon Sep 17 00:00:00 2001
From: Jiawei Wu
Date: Thu, 20 Jul 2023 16:46:44 +0800
Subject: [PATCH] [Torch-Dialect] emit aten.narrow.Tensor op and decompose it
 to aten.narrow op (#2297)

---
 e2e_testing/xfail_sets.py                     |  3 ++
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 25 +++++++++++++
 lib/Conversion/TorchToStablehlo/Basic.cpp     | 17 +++++++++
 lib/Dialect/Torch/IR/TorchOps.cpp             |  4 +++
 .../Transforms/AbstractInterpLibrary.cpp      |  8 +++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 22 ++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |  1 +
 lib/Dialect/Torch/Utils/Utils.cpp             |  5 +--
 .../build_tools/abstract_interp_lib_gen.py    | 10 ++++++
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 .../test_suite/slice_like.py                  | 36 +++++++++++++++++++
 11 files changed, 130 insertions(+), 2 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 48d136685..5bd7fd274 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -314,6 +314,7 @@ TORCHDYNAMO_CRASHING_SET = {
 
 STABLEHLO_PASS_SET = {
     "AliasModule_basic",
+    "TensorIntModule_basic",
     "AllBoolFalseModule_basic",
     "AllBoolTrueModule_basic",
     "AnyBoolFalseModule_basic",
@@ -751,6 +752,8 @@ STABLEHLO_PASS_SET = {
     "NarrowHorizontalTest_basic",
     "NarrowVerticalTest2_basic",
     "NarrowVerticalTest_basic",
+    "NarrowTensorHorizontalModule_basic",
+    "NarrowTensorVerticalModule_basic",
     "NumToTensorIntModule_basic",
     "NumpyTRank0Module_basic",
     "NumpyTRank1Module_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 860b674bb..51a6abc04 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11314,6 +11314,31 @@ def Torch_AtenNarrowOp : Torch_Op<"aten.narrow", [
   }];
 }
 
+def Torch_AtenNarrowTensorOp : Torch_Op<"aten.narrow.Tensor", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    AnyTorchTensorType:$start,
+    Torch_IntType:$length
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNarrowTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenNarrowTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index e37e0a609..6b7c3fd86 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -762,6 +762,22 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
   return success();
 }
 
+// AtenTensorIntOp
+template <>
+LogicalResult ConvertAtenOp<AtenTensorIntOp>::matchAndRewrite(
+    AtenTensorIntOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  RankedTensorType resultType = getTypeConverter()
+                                    ->convertType(op->getResult(0).getType())
+                                    .cast<RankedTensorType>();
+  Type outElementType = resultType.getElementType();
+  Value innerValue = adaptor.getT();
+  Value stablehloTensor =
+      hlo::scalarToStablehloTensor(rewriter, op, innerValue, outElementType);
+  rewriter.replaceOp(op, stablehloTensor);
+  return success();
+}
+
 // AtenReciprocalOp
 // Reciprocal(x) = Div(1, x)
 template <>
@@ -1699,6 +1715,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
 
   INSERT_ATENOP_PATTERN(AtenPermuteOp);
   INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
+  INSERT_ATENOP_PATTERN(AtenTensorIntOp);
   INSERT_ATENOP_PATTERN(AtenReciprocalOp);
   INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
   INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index ab1acaa43..91d4c40cf 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -156,6 +156,8 @@ static Value getScalarIntValue(Value input, Location loc,
   } else if (auto primNumToTensorScalarOp =
                  input.getDefiningOp<PrimNumToTensorScalarOp>()) {
     return primNumToTensorScalarOp.getA();
+  } else if (auto tensorIntOp = input.getDefiningOp<AtenTensorIntOp>()) {
+    return tensorIntOp.getT();
   }
   return nullptr;
 }
@@ -2557,6 +2559,8 @@ OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
   // aten.Int.Tensor, fold to the scalar number.
   if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
     return numToTensorScalar.getA();
+  if (auto tensorIntOp = getA().getDefiningOp<AtenTensorIntOp>())
+    return tensorIntOp.getT();
   return nullptr;
 }
 
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 5891e88a1..520c11c5e 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7543,6 +7543,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %3 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %1, %2, %int1) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
 "    return %3 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.narrow.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.int) -> !torch.list<int> {\n"
+"    %0 = torch.aten._set_item.t %arg0, %arg1, %arg3 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.slice_scatter\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.int) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -8430,6 +8434,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.narrow.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.neg\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 13996d6db..1db9440ce 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -340,6 +340,27 @@
 };
 } // namespace
 
+namespace {
+// Decompose `aten.narrow.Tensor` to `aten.narrow` op
+class DecomposeAtenNarrowTensorOp
+    : public OpRewritePattern<AtenNarrowTensorOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNarrowTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto *context = op.getContext();
+    // PyTorch makes sure that the `start` param is a 0-dim integral tensor.
+    // REF: https://pytorch.org/docs/stable/generated/torch.narrow.html.
+    auto start = rewriter.create<AtenScalarImplicitOp>(
+        loc, Torch::IntType::get(context), op.getStart());
+    rewriter.replaceOpWithNewOp<AtenNarrowOp>(
+        op, op.getType(), op.getSelf(), op.getDim(), start, op.getLength());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenZeroOp
     : public OpRewritePattern<AtenZeroOp> {
@@ -4753,6 +4774,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index fc9b845cf..6ab84e748 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -459,6 +459,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenNarrowTensorOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index b52416665..0e0e20f86 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -204,8 +204,9 @@ bool Torch::isViewLikeOp(Operation *op) {
              AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
              AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
              TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
-             AtenNarrowOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
-             PrimsViewOfOp, AtenRealOp, AtenImagOp, AtenViewAsComplexOp>(op);
+             AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
+             AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
+             AtenViewAsComplexOp>(op);
 }
 
 Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index e7794e0a3..a0ae1e027 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -900,6 +900,11 @@ def aten〇sort〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1, descend
 def aten〇narrow〡shape(self: List[int], dim: int, start: int, length: int) -> List[int]:
     return upstream_shape_functions.slice(self, dim, start, start + length, 1)
 
+# This shape function is a little hacky, because we don't know the start index, which is determined by a tensor param.
+def aten〇narrow〇Tensor〡shape(self: List[int], dim: int, start: List[int], length: int) -> List[int]:
+    self[dim] = length
+    return self
+
 def aten〇slice_scatter〡shape(self: List[int], src: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
     return self
 
@@ -1659,6 +1664,11 @@ def aten〇narrow〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start: int
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function([Invocation(TensorOfShape(3, 4, dtype=dtype, device=torch.device("cpu")), 0, ZeroDTensorWithDtype(dtype=torch.int64, device=torch.device("cpu")), 1) for dtype in _SORTED_TORCH_TYPES])
+def aten〇narrow〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start_rank_dtype: Tuple[int, int], length: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
 def aten〇neg〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 94938d2a4..5dbec3999 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -668,6 +668,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::eq.device : (Device, Device) -> (bool)")
     emit("aten::ceil.float : (float) -> (int)", has_folder=True)
     emit("aten::narrow : (Tensor, int, int, int) -> (Tensor)")
+    emit("aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)")
     emit("aten::ScalarImplicit : (Tensor) -> (Scalar)", has_canonicalizer=True)
 
     # backprop ops
diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py
index 0d67636a9..e18775a85 100644
--- a/python/torch_mlir_e2e_test/test_suite/slice_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py
@@ -523,6 +523,42 @@ def NarrowVerticalTest2_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class NarrowTensorHorizontalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True)
+    ])
+    def forward(self, x):
+        return torch.narrow(x, dim=1, start=torch.tensor(0), length=2)
+
+@register_test_case(module_factory=lambda: NarrowTensorHorizontalModule())
+def NarrowTensorHorizontalModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6,4))
+
+# ==============================================================================
+
+class NarrowTensorVerticalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True)
+    ])
+    def forward(self, x):
+        return torch.narrow(x, dim=1, start=torch.tensor(1), length=2)
+
+@register_test_case(module_factory=lambda: NarrowTensorVerticalModule())
+def NarrowTensorVerticalModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6,4))
+
+# ==============================================================================
+
 class SliceCopy_Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
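
For anyone who wants to sanity-check the new op outside the compiler: torch.narrow
accepts a 0-dim integral tensor as `start`, which is exactly the case that dispatches
to `aten::narrow.Tensor` and exercises the decomposition added above. The snippet
below is a minimal PyTorch-level sketch of that rewrite (scalar extraction followed
by the integer-start `aten.narrow`); it is illustrative only. The real pattern
operates on Torch-dialect IR, and `narrow_via_decomposition` is a hypothetical name
used here for illustration.

    import torch

    def narrow_via_decomposition(t: torch.Tensor, dim: int,
                                 start: torch.Tensor, length: int) -> torch.Tensor:
        # Mirror of DecomposeAtenNarrowTensorOp: pull the scalar out of the
        # 0-dim `start` tensor, then call the plain integer-start narrow.
        return torch.narrow(t, dim, int(start), length)

    x = torch.rand(6, 4)
    start = torch.tensor(1)                    # 0-dim integral tensor
    expected = torch.narrow(x, 1, start, 2)    # the aten::narrow.Tensor overload
    actual = narrow_via_decomposition(x, 1, start, 2)
    assert torch.equal(expected, actual)
    # The property the "hacky" shape function relies on: the result size along
    # `dim` is always `length`, independent of the runtime value of `start`.
    assert actual.shape == torch.Size([6, 2])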