From 00efec0b73c6b9e4f5ec1ab1276fbbffbea17389 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 10 May 2024 13:45:50 -0700 Subject: [PATCH] [linalg] Implement strict mode lowering for aten.view. (#3319) * Enables assume_strict_symbolic_shapes on fx_importer imported programs, indicating strict shape semantics. * Reworks the view->reshape lowering to take advantage of strict mode and do one of: * Collapse to 0D * Flatten/Unflatten when there is an inferred dim. * Fallback to tensor.reshape * Splits some test cases up and adds an attribute to control the old pattern (so new corners can be tested in strict mode in isolation). * Dynamic inferred mode needs upstream work to generalize expand_shape (so that case is suppressed here). * Deletes the assert from the existing tensor.reshape lowering if strict shape mode is enabled (since the condition it is dynamically asserting cannot happen). --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 199 +++++++++++++++++- python/torch_mlir/extras/fx_importer.py | 5 + test/Conversion/TorchToLinalg/view.mlir | 73 ++++--- .../Conversion/TorchToLinalg/view_strict.mlir | 150 +++++++++++++ 4 files changed, 395 insertions(+), 32 deletions(-) create mode 100644 test/Conversion/TorchToLinalg/view_strict.mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 67d13c5fb..d8dd75a9a 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -940,6 +940,9 @@ public: LogicalResult matchAndRewrite(AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (op->getParentOp()->hasAttr("torch.disable_legacy_view")) + return rewriter.notifyMatchFailure(op.getLoc(), + "legacy view lowering diabled"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); @@ -1284,6 +1287,9 @@ public: LogicalResult matchAndRewrite(AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (op->getParentOp()->hasAttr("torch.disable_legacy_view")) + return rewriter.notifyMatchFailure(op.getLoc(), + "legacy view lowering diabled"); SmallVector sizes; if (!getListConstructElements(op.getSize(), sizes)) return op.emitError( @@ -1319,12 +1325,16 @@ public: size = convert; } - // Check we are only inferring one dimension: - Value countPred = - b.create(arith::CmpIPredicate::sle, count, one); - b.create( - loc, countPred, - b.getStringAttr("must have at most one inferred (negative) dimension")); + // Check we are only inferring one dimension if not in strict mode. In + // strict mode, there will only ever statically be one inferred dim. + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value countPred = + b.create(arith::CmpIPredicate::sle, count, one); + b.create( + loc, countPred, + b.getStringAttr( + "must have at most one inferred (negative) dimension")); + } // Determine the total size of the inferred dimension and update the // inferred dimension: @@ -1356,6 +1366,165 @@ public: }; } // namespace +namespace { +class ConvertAtenViewOpStrict : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenViewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isAssumingStrictSymbolicShapes(rewriter)) + return rewriter.notifyMatchFailure(op.getLoc(), + "not strict symbolic shapes"); + SmallVector sizeValues; + if (!getListConstructElements(op.getSize(), sizeValues)) + return op.emitError( + "unimplemented: the tensor size list is not from list construct"); + + auto loc = op.getLoc(); + auto resultType = + cast(typeConverter->convertType(op.getType())); + auto self = adaptor.getSelf(); + auto selfTy = cast(self.getType()); + + // Handle collapse to 0D. + if (sizeValues.empty()) { + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getSelf(), ArrayRef{}); + return success(); + } + + // If there is a static inferred dimension (-1), then we emit a + // flatten/unflatten and let that proceed through its lowering. + // Otherwise, emit a tensor.reshape. Note that this relies on the fact that + // Torch does not allow such an op to have a symbolic inferred dim. + int inferredDim = -1; + bool staticSizes = true; + for (int i = 0, e = sizeValues.size(); i < e; ++i) { + int64_t dim; + if (!matchPattern(sizeValues[i], m_TorchConstantInt(&dim))) { + staticSizes = false; + continue; + } + if (dim == -1) { + inferredDim = i; + break; + } + } + + // While it should be illegal to have a view op with fully known sizes + // and a dynamic shape, in reality, torch IR is a bit loosey and + // progressively resolves to this state. There are delicate invariants + // on the ops we produce that require this, so we enforce. + if (staticSizes && !resultType.hasStaticShape()) { + return rewriter.notifyMatchFailure(loc, + "view cannot be converted with static " + "sizes and a dynamic result type"); + } + + // Handle inferred dim case. + // TODO: Remove the restriction on staticSizes once flatten/unflatten + // reliably work with multiple dynamic dimensions. + if (inferredDim >= 0 && staticSizes) { + if (!staticSizes) { + return rewriter.notifyMatchFailure( + loc, "view to flatten/unflatten only supported for static sizes"); + } + // This is a torch-torch conversion, so only non adapted types are + // involved. + auto selfTy = dyn_cast(op.getSelf().getType()); + if (!selfTy || !selfTy.hasSizes()) + return failure(); + + // Work out the 1D flattened type. + int64_t flatDim = 1; + auto selfSizes = selfTy.getSizes(); + for (int64_t dim : selfSizes) { + if (dim == kUnknownSize) { + flatDim = kUnknownSize; + break; + } + flatDim *= dim; + } + // Flatten to 1D. + ValueTensorType flatType = rewriter.getType( + ArrayRef{flatDim}, selfTy.getOptionalDtype()); + Value dimStart = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value dimEnd = rewriter.create( + loc, rewriter.getI64IntegerAttr(selfSizes.size() - 1)); + Value flatSelf = rewriter.create( + loc, flatType, op.getSelf(), dimStart, dimEnd); + + // Unflatten to requested size. + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), flatSelf, dimStart, op.getSize()); + return success(); + } + + // Generate output dims, either based on whether there is an inferred dim + // present or all dims are specified. + auto sizeTy = cast( + typeConverter->convertType(sizeValues.front().getType())); + SmallVector outputDimValues; + assert(sizeTy && "Type converter did not handle size"); + if (inferredDim >= 0) { + // Inferred dim. If the above flatten/unflatten logic ever catches + // everything, this branch can go away entirely. + Value one = rewriter.create( + loc, sizeTy, rewriter.getIntegerAttr(sizeTy, 1)); + Value sizeProduct = one; + // Multiply the non-inferred target sizes. + for (int i = 0, e = sizeValues.size(); i < e; ++i) { + if (i == inferredDim) + continue; + Value size = sizeValues[i]; + Value convertedSize = typeConverter->materializeTargetConversion( + rewriter, loc, sizeTy, size); + assert(convertedSize && "Type converter did not handle size"); + sizeProduct = + rewriter.create(loc, sizeProduct, convertedSize); + } + + // Multiply the self tensor sizes. + Value selfProduct = one; + for (int i = 0, e = selfTy.getRank(); i < e; ++i) { + Value index = rewriter.create(loc, i); + Value dim = rewriter.create(loc, self, index); + dim = rewriter.create(loc, sizeTy, dim); + selfProduct = rewriter.create(loc, selfProduct, dim); + } + + Value inferredSize = + rewriter.create(loc, selfProduct, sizeProduct); + for (int i = 0, e = sizeValues.size(); i < e; ++i) { + if (i == inferredDim) { + outputDimValues.push_back(inferredSize); + } else { + outputDimValues.push_back(typeConverter->materializeTargetConversion( + rewriter, loc, sizeTy, sizeValues[i])); + } + } + } else { + // No inferred dim. So output dims are just pass through. + for (Value torchSize : sizeValues) { + outputDimValues.push_back(typeConverter->materializeTargetConversion( + rewriter, loc, sizeTy, torchSize)); + } + } + + // Normal lowering to reshape with fully computed sizes. + auto outputDimsTy = RankedTensorType::get( + outputDimValues.size(), outputDimValues.front().getType()); + auto outputDims = rewriter.create(loc, outputDimsTy, + outputDimValues); + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getSelf(), outputDims); + return success(); + } +}; +} // namespace + namespace { class ConvertAtenSqueezeOp : public OpConversionPattern { public: @@ -2459,6 +2628,9 @@ SmallVector ConvertSparseOperatorOp::legalizedNames = { void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { + // Add some legal ops for torch-torch lowering. + target.addLegalOp(); + MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); @@ -2468,10 +2640,23 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( patterns.add(typeConverter, context); patterns.add(typeConverter, context); target.addIllegalOp(); + + // View op sadness: In the future, we only want ConvertAtenViewOpStrict, + // but this requires work upstream to fully generalize reshape handling. + // In the meantime, the analysis based ConvertAtenViewOp tries hard to + // produce expand/collapse shapes, the ConvertAtenViewOpStrict does the + // right thing but cannot be fully supported for dynamic shapes, and + // ConvertAtenViewOpToReshape overly pessimizes and generates a lot of IR + // due to not statically switching between inferred and non-inferred view + // cases. They are ordered by optimiality of the lowerings they generate + // when they are able. target.addIllegalOp(); - patterns.add(typeConverter, context, /*benefit=*/200); + patterns.add(typeConverter, context, /*benefit=*/300); + patterns.add(typeConverter, context, + /*benefit=*/200); patterns.add(typeConverter, context, /*benefit=*/100); + target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 24bda3f5b..381f8f9ad 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -103,6 +103,7 @@ from ..ir import ( StringAttr, SymbolTable, Type as IrType, + UnitAttr, Value, ) @@ -642,6 +643,10 @@ class FxImporter: func_op = func_dialect.FuncOp( func_name, ftype, ip=self._m_ip, visibility=func_visibility ) + # Programs imported from FX have strong guarantees. Setting this attribute + # causes various lowerings to be able to emit more efficient code or + # handle more cases. See isAssumingStrictSymbolicShapes(). + func_op.attributes["torch.assume_strict_symbolic_shapes"] = UnitAttr.get() entry_block = Block.create_at_start(func_op.body, ftype.inputs) node_importer = GraphNodeImporter( diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir index 20f0301cc..3d265a308 100644 --- a/test/Conversion/TorchToLinalg/view.mlir +++ b/test/Conversion/TorchToLinalg/view.mlir @@ -1,16 +1,17 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s - // ----- // CHECK-LABEL: func.func @torch.aten.view$twotothree( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32> // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32> // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3],f32> -func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> { +func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int3 = torch.constant.int 3 %int2 = torch.constant.int 2 %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list @@ -21,13 +22,15 @@ func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torc // ----- // CHECK-LABEL: func.func @torch.aten.view$dynamictest( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]] // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[RESHAPE]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,?],f32> -func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int @@ -40,13 +43,15 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // ----- // CHECK-LABEL: func.func @torch.aten.view$dynamictest2( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,6,?],f32> -> tensor // CHECK: %[[EXPAND:.*]] = tensor.reshape %[[BUILTIN_TENSOR]] // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor -> !torch.vtensor<[?,2,3,?],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,2,3,?],f32> -func.func @torch.aten.view$dynamictest2(%arg0: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> { +func.func @torch.aten.view$dynamictest2(%arg0: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int3 = torch.constant.int 3 %int2 = torch.constant.int 2 %int0 = torch.constant.int 0 @@ -60,7 +65,7 @@ func.func @torch.aten.view$dynamictest2(%arg0: !torch.vtensor<[?,6,?],f32>) -> ! // ----- // CHECK-LABEL: func.func @torch.aten.view$dynamicVal( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[1,?,128],f32> -> tensor<1x?x128xf32> // CHECK: %[[CASTED:.*]] = tensor.cast %[[BUILTIN_TENSOR]] : tensor<1x?x128xf32> to tensor<1x16x128xf32> // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED]] {{\[\[}}0, 1], [2]] : tensor<1x16x128xf32> into tensor<16x128xf32> @@ -68,7 +73,9 @@ func.func @torch.aten.view$dynamictest2(%arg0: !torch.vtensor<[?,6,?],f32>) -> ! // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<16x1x128xf32> -> !torch.vtensor<[16,1,128],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[16,1,128],f32> -func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> { +func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int128 = torch.constant.int 128 %int1 = torch.constant.int 1 %int16 = torch.constant.int 16 @@ -80,7 +87,7 @@ func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> ! // ----- // CHECK-LABEL: func.func @torch.aten$dynamicValOutput( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32> // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2]] : tensor<4x5x6xf32> into tensor<120xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2, 3]] output_shape [8, 1, 15, 1] : tensor<120xf32> into tensor<8x1x15x1xf32> @@ -88,7 +95,9 @@ func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> ! // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<8x1x?x1xf32> -> !torch.vtensor<[8,1,?,1],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[8,1,?,1],f32> -func.func @torch.aten$dynamicValOutput(%arg0: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32> { +func.func @torch.aten$dynamicValOutput(%arg0: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[8,1,?,1],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int8 = torch.constant.int 8 %int1 = torch.constant.int 1 %int-1 = torch.constant.int -1 @@ -100,7 +109,7 @@ func.func @torch.aten$dynamicValOutput(%arg0: !torch.vtensor<[4,5,6],f32>) -> !t // ----- // CHECK-LABEL: func.func @torch.aten$dynamicValOutput2( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32> // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2]] : tensor<4x5x6xf32> into tensor<4x30xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [2, 1, 2, 3, 10] : tensor<4x30xf32> into tensor<2x1x2x3x10xf32> @@ -109,7 +118,9 @@ func.func @torch.aten$dynamicValOutput(%arg0: !torch.vtensor<[4,5,6],f32>) -> !t // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,1,2,3,?],f32> // 4 -> [2,1,2] [5,6] -> [3,10]. -func.func @torch.aten$dynamicValOutput2(%arg0: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32> { +func.func @torch.aten$dynamicValOutput2(%arg0: !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[2,1,2,3,?],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 %int3 = torch.constant.int 3 @@ -122,14 +133,16 @@ func.func @torch.aten$dynamicValOutput2(%arg0: !torch.vtensor<[4,5,6],f32>) -> ! // ----- // CHECK-LABEL: func.func @torch.aten.view$expandInferredDim( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,6],f32> -> tensor<2x6xf32> // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1]] : tensor<2x6xf32> into tensor<12xf32> // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [3, 2, 2] : tensor<12xf32> into tensor<3x2x2xf32> // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPANDED]] : tensor<3x2x2xf32> -> !torch.vtensor<[3,2,2],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[3,2,2],f32> -func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> { +func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 %int-1 = torch.constant.int -1 @@ -141,7 +154,7 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) - // ----- // CHECK-LABEL: func.func @torch.aten.view$singleUnknownMatches0( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[10,3,?,2,3],f32> -> tensor<10x3x?x2x3xf32> // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2], [3, 4]] : tensor<10x3x?x2x3xf32> into tensor<30x?x6xf32> // CHECK: %[[C1:.*]] = arith.constant 1 : index @@ -154,7 +167,9 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) - // Associations are, // -- for collapse, [0,1], [2], [3,4] and // -- for expand [0,1,2], [3], [4]. -func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> { +func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int3 = torch.constant.int 3 %int2 = torch.constant.int 2 %int6 = torch.constant.int 6 @@ -175,13 +190,15 @@ func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2 // but one which matches between the input and the output // CHECK: func.func @torch.aten.view$combineConcepts( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[8,?,?,?,2,1,3],f32>) -> !torch.vtensor<[2,2,2,?,?,?,6],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[8,?,?,?,2,1,3],f32>) -> !torch.vtensor<[2,2,2,?,?,?,6],f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[8,?,?,?,2,1,3],f32> -> tensor<8x?x?x?x2x1x3xf32> // CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[BUILTIN_TENSOR]] // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[RESHAPE]] : tensor<2x2x2x?x?x?x6xf32> -> !torch.vtensor<[2,2,2,?,?,?,6],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,2,2,?,?,?,6],f32> -func.func @torch.aten.view$combineConcepts(%arg0 : !torch.vtensor<[8,?,?,?,2,1,3], f32>) -> !torch.vtensor<[2,2,2,?,?,?,6], f32> { +func.func @torch.aten.view$combineConcepts(%arg0 : !torch.vtensor<[8,?,?,?,2,1,3], f32>) -> !torch.vtensor<[2,2,2,?,?,?,6], f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int1 = torch.constant.int 1 %size1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[8,?,?,?,2,1,3], f32>, !torch.int -> !torch.int @@ -200,12 +217,14 @@ func.func @torch.aten.view$combineConcepts(%arg0 : !torch.vtensor<[8,?,?,?,2,1,3 // ----- // CHECK-LABEL: func.func @torch.aten.view$multiDynamicsInSourceOfCollapse -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,2,?,4,?],f32>) -> !torch.vtensor<[?],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,2,?,4,?],f32>) -> !torch.vtensor<[?],f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,2,?,4,?],f32> -> tensor // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2, 3, 4]] : tensor into tensor // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[COLLAPSE]] : tensor -> !torch.vtensor<[?],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?],f32> -func.func @torch.aten.view$multiDynamicsInSourceOfCollapse (%arg0 : !torch.vtensor<[?,2,?,4,?], f32>) -> !torch.vtensor<[?], f32> { +func.func @torch.aten.view$multiDynamicsInSourceOfCollapse (%arg0 : !torch.vtensor<[?,2,?,4,?], f32>) -> !torch.vtensor<[?], f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int-1 = torch.constant.int -1 %0 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,2,?,4,?], f32>, !torch.list -> !torch.vtensor<[?], f32> @@ -215,7 +234,7 @@ func.func @torch.aten.view$multiDynamicsInSourceOfCollapse (%arg0 : !torch.vtens // ----- // CHECK-LABEL: func.func @torch.aten.view$castingView -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[3,4,5],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[3,4,5],f32> // The current lowring only succeeds if the input (arg0) has shape [3,4,5], // determined at runtime. This is a bit limiting, and we'll probably want to @@ -225,7 +244,9 @@ func.func @torch.aten.view$multiDynamicsInSourceOfCollapse (%arg0 : !torch.vtens // CHECK-COUNT-2: cf.assert {{.*}} "mismatching contracting dimension // CHECK: return {{.*}} : !torch.vtensor<[3,4,5],f32> -func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) -> !torch.vtensor<[3,4,5], f32> { +func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) -> !torch.vtensor<[3,4,5], f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int3 = torch.constant.int 3 %int4 = torch.constant.int 4 %int5 = torch.constant.int 5 @@ -240,7 +261,7 @@ func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) -> // We expect this to lower to a collapse with [0], [1], [2,3] followed by // an expand with [0,1], [2], [3]: // CHECK: func.func @torch.aten.view$dynamicInferredSame( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32> // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[10,?,2,3],f32> -> tensor<10x?x2x3xf32> // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1], [2, 3]] : tensor<10x?x2x3xf32> into tensor<10x?x6xf32> // CHECK: %[[C1:.*]] = arith.constant 1 : index @@ -249,7 +270,9 @@ func.func @torch.aten.view$castingView (%arg0 : !torch.vtensor<[?,?,?], f32>) -> // CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor<2x5x?x6xf32> -> !torch.vtensor<[2,5,?,6],f32> // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,5,?,6],f32> -func.func @torch.aten.view$dynamicInferredSame(%arg0: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32> { +func.func @torch.aten.view$dynamicInferredSame(%arg0: !torch.vtensor<[10,?,2,3],f32>) -> !torch.vtensor<[2,5,?,6],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ %int2 = torch.constant.int 2 %int5 = torch.constant.int 5 %int6 = torch.constant.int 6 diff --git a/test/Conversion/TorchToLinalg/view_strict.mlir b/test/Conversion/TorchToLinalg/view_strict.mlir new file mode 100644 index 000000000..8be9a2f9f --- /dev/null +++ b/test/Conversion/TorchToLinalg/view_strict.mlir @@ -0,0 +1,150 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s +// Since we want to migrate to the strict view op lowering, these test cases +// verify this one pattern specifically via attributes on the functions that +// disable the legacy behavior. + +// ----- + +// CHECK-LABEL: func.func @torch.aten.view$twotothree +// CHECK: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32> +// CHECK: %[[T3:.*]] = torch.constant.int 3 +// CHECK: %[[T2:.*]] = torch.constant.int 2 +// CHECK: %[[N2:.*]] = torch_c.to_i64 %[[T2]] +// CHECK: %[[N3:.*]] = torch_c.to_i64 %[[T3]] +// CHECK: %[[ELEMENTS:.*]] = tensor.from_elements %[[N2]], %[[N3]] : tensor<2xi64> +// CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[ARG0]](%[[ELEMENTS]]) : (tensor<3x2xf32>, tensor<2xi64>) -> tensor<2x3xf32> +func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> + attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view} +{ + %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[3,2],f32>, !torch.list -> !torch.vtensor<[2,3],f32> + return %1 : !torch.vtensor<[2,3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.view$zerod +// CHECK: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 +// CHECK: tensor.collapse_shape %0 [] : tensor into tensor +func.func @torch.aten.view$zerod(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> + attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view} +{ + %0 = torch.prim.ListConstruct : () -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[],f32> + return %1 : !torch.vtensor<[],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.view$dynamictest +// CHECK: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 +// CHECK: %[[ARG1:.*]] = torch_c.to_i64 %arg1 +// CHECK: %[[ARG2:.*]] = torch_c.to_i64 %arg2 +// CHECK: %[[ELTS:.*]] = tensor.from_elements %[[ARG1]], %[[ARG2]] : tensor<2xi64> +// CHECK: tensor.reshape %[[ARG0]](%[[ELTS]]) : (tensor, tensor<2xi64>) -> tensor +func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> + attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view} +{ + %2 = torch.prim.ListConstruct %arg1, %arg2 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?,?],f32> + return %3 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.view$dynamictest2( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> +// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,6,?],f32> -> tensor +// CHECK: %[[EXPAND:.*]] = tensor.reshape %[[BUILTIN_TENSOR]] +// CHECK: %[[BUILTIN_TENSOR_CAST:.*]] = torch_c.from_builtin_tensor %[[EXPAND]] : tensor -> !torch.vtensor<[?,2,3,?],f32> +// CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[?,2,3,?],f32> + +func.func @torch.aten.view$dynamictest2(%arg0: !torch.vtensor<[?,6,?],f32>) -> !torch.vtensor<[?,2,3,?],f32> + attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view} +{ + %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %int0 = torch.constant.int 0 + %2 = torch.aten.size.int %arg0, %int2 : !torch.vtensor<[?,6,?],f32>, !torch.int -> !torch.int + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,6,?],f32>, !torch.int -> !torch.int + %1 = torch.prim.ListConstruct %0, %int2, %int3, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.aten.view %arg0, %1 : !torch.vtensor<[?,6,?],f32>, !torch.list -> !torch.vtensor<[?,2,3,?], f32> + return %3 : !torch.vtensor<[?,2,3,?], f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.view$dynamicVal( +// CHECK: tensor.reshape {{.*}} : (tensor<1x?x128xf32>, tensor<3xi64>) -> tensor<16x1x128xf32> +func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> + attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view} +{ + %int128 = torch.constant.int 128 + %int1 = torch.constant.int 1 + %int16 = torch.constant.int 16 + %0 = torch.prim.ListConstruct %int16, %int1, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[16,1,128],f32> + return %1 : !torch.vtensor<[16,1,128],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.view$expandInferredDim +// CHECK: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : tensor<2x6xf32> into tensor<12xf32> +// CHECK: %[[CAST1:.*]] = tensor.cast %[[COLLAPSED]] : tensor<12xf32> to tensor<12xf32> +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[CAST1]] {{\[\[}}0, 1, 2]] output_shape [3, 2, 2] : tensor<12xf32> into tensor<3x2x2xf32> +func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> + attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view} +{ + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %int-1 = torch.constant.int -1 + %0 = torch.prim.ListConstruct %int3, %int2, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[2,6],f32>, !torch.list -> !torch.vtensor<[3,2,2],f32> + return %1 : !torch.vtensor<[3,2,2],f32> +} + +// ----- +// Note that this is presently going down a fallback path as an explicit +// reshape. Someday, this should generate flatten/unflatten. +// CHECK-LABEL: func.func @torch.aten$dynamicValOutput +// CHECK: %[[SELF:.*]] = torch_c.to_builtin_tensor %arg0 +// CHECK: %[[CONSTANT1:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[PROD1:.*]] = arith.constant 1 +// CHECK-DAG: %[[ARG1_CVT:.*]] = torch_c.to_i64 %arg1 +// CHECK-DAG: %[[PROD2:.*]] = arith.muli %[[PROD1]], %[[ARG1_CVT]] +// CHECK-DAG: %[[ONEI64:.*]] = torch_c.to_i64 %[[CONSTANT1]] +// CHECK-DAG: %[[PROD3:.*]] = arith.muli %[[PROD2]], %[[ONEI64]] +// CHECK-DAG: %[[ONEI64_0:.*]] = torch_c.to_i64 %[[CONSTANT1]] +// CHECK-DAG: %[[PROD4:.*]] = arith.muli %[[PROD3]], %[[ONEI64_0]] +// CHECK-DAG: %[[INDEX0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[DIM0_INDEX:.*]] = tensor.dim %[[SELF]], %[[INDEX0]] : tensor +// CHECK-DAG: %[[DIM0:.*]] = arith.index_cast %[[DIM0_INDEX]] : index to i64 +// CHECK-DAG: %[[KNOWN0:.*]] = arith.muli %[[PROD1]], %[[DIM0]] : i64 +// CHECK-DAG: %[[INDEX1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[DIM1_INDEX:.*]] = tensor.dim %[[SELF]], %[[INDEX1]] : tensor +// CHECK-DAG: %[[DIM1:.*]] = arith.index_cast %[[DIM1_INDEX]] : index to i64 +// CHECK-DAG: %[[KNOWN1:.*]] = arith.muli %[[KNOWN0]], %[[DIM1]] : i64 +// CHECK-DAG: %[[INDEX2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM2_INDEX:.*]] = tensor.dim %[[SELF]], %[[INDEX2]] : tensor +// CHECK-DAG: %[[DIM2:.*]] = arith.index_cast %[[DIM2_INDEX]] : index to i64 +// CHECK-DAG: %[[KNOWN2:.*]] = arith.muli %[[KNOWN1]], %[[DIM2]] : i64 +// CHECK-DAG: %[[DIMINFER:.*]] = arith.divui %[[KNOWN2]], %[[PROD4]] : i64 +// CHECK: %[[DIM0:.*]] = torch_c.to_i64 %arg1 +// CHECK: %[[DIM1:.*]] = torch_c.to_i64 %[[CONSTANT1]] +// CHECK: %[[DIM3:.*]] = torch_c.to_i64 %[[CONSTANT1]] +// CHECK: %[[OUTPUT_DIMS:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]], %[[DIMINFER]], %[[DIM3]] : tensor<4xi64> +// CHECK: tensor.reshape %[[SELF]](%[[OUTPUT_DIMS]]) : (tensor, tensor<4xi64>) -> tensor +// +func.func @torch.aten$dynamicValOutput(%arg0: !torch.vtensor<[?, ?, ?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?,1,?,1],f32> + attributes {torch.assume_strict_symbolic_shapes, torch.disable_legacy_view} +{ + %int1 = torch.constant.int 1 + %int-1 = torch.constant.int -1 + %0 = torch.prim.ListConstruct %arg1, %int1, %int-1, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?, ?, ?],f32>, !torch.list -> !torch.vtensor<[?,1,?,1],f32> + return %1 : !torch.vtensor<[?,1,?,1],f32> +}