From 0e77de996aa715fb75aff4c3dd3d10c8c9c01853 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 18 Apr 2024 11:47:19 -0700 Subject: [PATCH] [torch] Add support for `torch.view` with dynamic shapes (#3164) We can map to `tensor.reshape` for handling multiple output dynamic shapes. Later we can perform a more complex analysis for identifying expand/collapse cases from the tensor.reshape. Initially we planned to handle this identification at the `torch` level; however, it will be easier to handle once converted to core mlir-dialects. --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 95 ++++++++++++++++++- projects/pt1/e2e_testing/main.py | 3 +- projects/pt1/e2e_testing/xfail_sets.py | 15 +-- .../test_suite/reshape_like.py | 24 ++++- test/Conversion/TorchToLinalg/view.mlir | 12 +-- 5 files changed, 127 insertions(+), 22 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 5a47a247a..a94f8882e 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1003,8 +1003,14 @@ public: // collapsed. Note this may technically not always be true. // TODO: think of a way better way to at least detect when this assumption // is violated for the cases of dynamic dimensions. 
- bool inputHasOneDynDim = llvm::count(inputShape, kUnknownSize) == 1; - bool outputHasOneDynDim = llvm::count(outputShape, kUnknownSize) == 1; + int64_t inputDynDim = llvm::count(inputShape, kUnknownSize); + int64_t outputDynDim = llvm::count(outputShape, kUnknownSize); + if (outputDynDim > 1) + return rewriter.notifyMatchFailure( + op, "Cannot support more than one output dynamic dimension"); + + bool inputHasOneDynDim = inputDynDim == 1; + bool outputHasOneDynDim = outputDynDim == 1; bool singleDynDimsAreEqual = inputHasOneDynDim && outputHasOneDynDim && productReduce(inputShape) == productReduce(outputShape); @@ -1271,6 +1277,85 @@ public: }; } // namespace +namespace { +class ConvertAtenViewOpToReshape : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector sizes; + if (!getListConstructElements(op.getSize(), sizes)) + return op.emitError( + "unimplemented: the tensor size list is not from list construct"); + + auto loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + auto self = adaptor.getSelf(); + const TypeConverter *typeConverter = getTypeConverter(); + + // Convert to the `linalg` types, count the number of negative values, + // and determine the product of non-negative values. This lets us compute + // the inferred dimensions sizes. 
+ auto sizeTy = + cast(typeConverter->convertType(sizes.front().getType())); + Value one = + b.create(sizeTy, rewriter.getIntegerAttr(sizeTy, 1)); + Value zero = + b.create(sizeTy, rewriter.getIntegerAttr(sizeTy, 0)); + Value count = zero; + Value knownSize = one; + for (auto &size : sizes) { + Value convert = typeConverter->materializeTargetConversion(rewriter, loc, + sizeTy, size); + + Value mul = b.create(knownSize, convert); + Value add = b.create(count, one); + Value isNeg = + b.create(arith::CmpIPredicate::slt, convert, zero); + + knownSize = b.create(isNeg, knownSize, mul); + count = b.create(isNeg, add, count); + size = convert; + } + + // Check we are only inferring one dimension: + Value countPred = + b.create(arith::CmpIPredicate::sle, count, one); + b.create( + loc, countPred, + b.getStringAttr("must have at most one inferred (negative) dimension")); + + // Determine the total size of the inferred dimension and update the + // inferred dimension: + auto selfTy = cast(self.getType()); + Value totalSize = one; + for (int i = 0, s = selfTy.getRank(); i < s; ++i) { + Value index = b.create(i); + Value dim = b.create(self, index); + dim = b.create(sizeTy, dim); + totalSize = b.create(totalSize, dim); + } + + Value inferredSize = b.create(totalSize, knownSize); + for (auto &size : sizes) { + Value isNeg = + b.create(arith::CmpIPredicate::slt, size, zero); + size = b.create(isNeg, inferredSize, size); + } + + auto ty = RankedTensorType::get(sizes.size(), sizes.front().getType()); + auto outputDims = b.create(ty, sizes); + + auto resultType = + typeConverter->convertType(op.getType()).cast(); + rewriter.replaceOpWithNewOp(op, resultType, self, + outputDims); + return success(); + } +}; +} // namespace + namespace { class ConvertAtenSqueezeOp : public OpConversionPattern { public: @@ -2348,10 +2433,12 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); 
patterns.add(typeConverter, context); - target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); - patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context, /*benefit=*/200); + patterns.add(typeConverter, context, + /*benefit=*/100); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py index 9f2323793..d2c381d65 100644 --- a/projects/pt1/e2e_testing/main.py +++ b/projects/pt1/e2e_testing/main.py @@ -32,6 +32,7 @@ from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTen from .xfail_sets import ( LINALG_XFAIL_SET, + LINALG_CRASHING_SET, MAKE_FX_TOSA_PASS_SET, STABLEHLO_PASS_SET, STABLEHLO_CRASHING_SET, @@ -99,7 +100,7 @@ def main(): if args.config == "linalg": config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = LINALG_XFAIL_SET - crashing_set = set() + crashing_set = LINALG_CRASHING_SET elif args.config == "stablehlo": config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) xfail_set = all_test_unique_names - STABLEHLO_PASS_SET diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index f6949abcf..72b667802 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -24,6 +24,11 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | { "SplitWithSizes_Module_basic", } +LINALG_CRASHING_SET = { + # Crashes due to copy to a smaller destination buffer than the source buffer. 
+ "SliceCopyStartGreaterThanDimSize_Module_basic", +} + TORCHDYNAMO_XFAIL_SET = { #### General TorchDynamo/PyTorch errors @@ -2280,15 +2285,6 @@ ONNX_XFAIL_SET = { "ElementwiseToDtypeI64ToUI8Module_basic", # Failure - torch.aten.view lower - "IndexTensorDyanmicInputContiguousWithNoneModule_basic", - "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", - "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", - "IndexTensorMultiInputContiguousCenter_basic", - "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", - "IndexTensorMultiInputNonContiguous_basic", - "IndexTensorMultiInputOneDim_basic", - "IndexTensorMultiInputThreeIndexers_basic", - "IndexTensorMultiInput_basic", "IndexTensorMultiInputContiguousOneDimDynamic_basic", "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", @@ -2327,7 +2323,6 @@ ONNX_XFAIL_SET = { "EmbeddingModuleF16_basic", "EmbeddingModuleI32_basic", "EmbeddingModuleI64_basic", - "FlattenDynamicModule_basic", "GluStaticModule_basic", "GroupNormModule_basic", "IndexTensorHackedTwinModule3dInput_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 73b15afe9..8aa3e2c1f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -992,6 +992,28 @@ class ReshapeAliasExpandModule(torch.nn.Module): def ReshapeAliasExpandModule_basic(module, tu: TestUtils): module.forward(tu.rand(384)) + +# ============================================================================== + +class ReshapeDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + + def forward(self, a): + return a.view(a.size(1), a.size(0)) + +@register_test_case(module_factory=lambda: ReshapeDynamicModule()) +def ReshapeDynamicModule_basic(module, tu: 
TestUtils): + module.forward(tu.rand(3,4)) + + + # ============================================================================== class ReshapeAliasCollapseModule(torch.nn.Module): @@ -1153,4 +1175,4 @@ class EinsumStaticWithEllipsisSlicingAndBroadcastModule(torch.nn.Module): @register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingAndBroadcastModule()) def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5)) \ No newline at end of file + module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5)) diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir index 4f9c1f867..7cad9ffe3 100644 --- a/test/Conversion/TorchToLinalg/view.mlir +++ b/test/Conversion/TorchToLinalg/view.mlir @@ -23,7 +23,8 @@ func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.view$dynamictest( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[BUILTIN_TENSOR]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]] +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[RESHAPE]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -31,7 +32,7 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor %int0 = torch.constant.int 0 %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int - %2 = torch.prim.ListConstruct %0, %1 : 
(!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %1, %0 : (!torch.int, !torch.int) -> !torch.list %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?,?],f32> return %3 : !torch.vtensor<[?,?],f32> } @@ -41,7 +42,7 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK-LABEL: func.func @torch.aten.view$dynamictest2( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> { // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,6,?],f32> -> tensor -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2], [3]] : tensor into tensor +// CHECK: %[[EXPAND:.*]] = tensor.reshape %[[BUILTIN_TENSOR]] // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor -> !torch.vtensor<[?,2,3,?],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,2,3,?],f32> @@ -174,9 +175,8 @@ func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2 // CHECK: func.func @torch.aten.view$combineConcepts( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[8,?,?,?,2,1,3],f32>) -> !torch.vtensor<[2,2,2,?,?,?,6],f32> { // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[8,?,?,?,2,1,3],f32> -> tensor<8x?x?x?x2x1x3xf32> -// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2], [3], [4, 5, 6]] : tensor<8x?x?x?x2x1x3xf32> into tensor<8x?x?x?x6xf32> -// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[COLLAPSE]] {{\[\[}}0, 1, 2], [3], [4], [5], [6]] : tensor<8x?x?x?x6xf32> into tensor<2x2x2x?x?x?x6xf32> -// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x2x2x?x?x?x6xf32> -> !torch.vtensor<[2,2,2,?,?,?,6],f32> +// CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]] +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor 
%[[RESHAPE]] : tensor<2x2x2x?x?x?x6xf32> -> !torch.vtensor<[2,2,2,?,?,?,6],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,2,2,?,?,?,6],f32> func.func @torch.aten.view$combineConcepts(%arg0 : !torch.vtensor<[8,?,?,?,2,1,3], f32>) -> !torch.vtensor<[2,2,2,?,?,?,6], f32> {