diff --git a/lib/Conversion/TorchToMhlo/ViewLike.cpp b/lib/Conversion/TorchToMhlo/ViewLike.cpp index d72f4d013..0502694cd 100644 --- a/lib/Conversion/TorchToMhlo/ViewLike.cpp +++ b/lib/Conversion/TorchToMhlo/ViewLike.cpp @@ -53,8 +53,8 @@ Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op, } Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, - Value input, Value startIndex, Value endIndex, - Value step, size_t dimIndex, + Type outTy, Value input, Value startIndex, + Value endIndex, Value step, size_t dimIndex, ArrayRef dimSizes) { auto loc = op->getLoc(); // startIndex & endIndex has been normailized into range [0, dSize] @@ -98,20 +98,15 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, auto stridesTensor = rewriter.create(loc, strides).getResult(); - auto inputShape = inputTy.getShape(); - SmallVector sliceShape(inputShape.begin(), inputShape.end()); - sliceShape[dimIndex] = ShapedType::kDynamicSize; - auto sliceoutputTy = - RankedTensorType::get(sliceShape, inputTy.getElementType()); return rewriter.create( - loc, sliceoutputTy, input, startTensor, endTensor, stridesTensor); + loc, outTy, input, startTensor, endTensor, stridesTensor); } // Get a dynamic slice of the tensor from startIndex to endIndex with stride // step on the specifed dimension. The input startIndex(default to 0), // endIndex(default to dimSize), and step(default to 1) can be optional. FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, - Value input, + Type outTy, Value input, llvm::Optional startIndexOpt, llvm::Optional endIndexOpt, llvm::Optional stepOpt, int64_t dim) { @@ -152,7 +147,7 @@ FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, op, "failed to get dimension sizes of the input"); auto dimSizes = *dimSizesInfo; - return getDynamicSliceInternal(rewriter, op, input, normStartIndex, + return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex, normEndIndex, step, dim, dimSizes); } @@ -174,6 +169,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfTy = self.getType().template cast(); if (!selfTy) return op.emitError("only ranked tensor types are supported"); + auto outTy = + getTypeConverter()->convertType(op.getType()).cast(); int64_t dim; if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( @@ -192,14 +189,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm::Optional step = getOptionalVal(adaptor.step()); FailureOr sliceInfo = - getDynamicSlice(rewriter, op, self, start, end, step, dim); + getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim); if (failed(sliceInfo)) return op.emitError("can not create a dynmaic slice"); auto slice = *sliceInfo; - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), slice); - + rewriter.replaceOp(op, slice); return success(); } @@ -207,14 +202,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // specialized. template class ConvertAtenViewOp : public OpConversionPattern { - public: +public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; - LogicalResult matchAndRewrite( - AtenOpT op, - OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto rankType = adaptor.self().getType().template dyn_cast(); if (!rankType) @@ -236,18 +230,19 @@ class ConvertAtenViewOp : public OpConversionPattern { return success(); } - std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) { + std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { dSize = rewriter.create(loc, dSize).getResult(); return dSize; }); #ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 - // The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU. - // One can truncate from i64 to i32 since dimension sizes are unlikely to exceed - // the range of i32(4GiB) - std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) { + // The i64 calculation is much slower than i32 on some devices, such as + // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are + // unlikely to exceed the range of i32(4GiB) + std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { // dimSize: cast i64 -> i32 - dSize = rewriter.create(loc, rewriter.getI32Type(), dSize); + dSize = + rewriter.create(loc, rewriter.getI32Type(), dSize); return dSize; }); #endif @@ -272,28 +267,22 @@ class ConvertAtenViewOp : public OpConversionPattern { return success(); } - bool getAtenViewOpSizes( - AtenOpT op, - OpAdaptor adaptor, - ConversionPatternRewriter& rewriter, - SmallVector& dimSizes) const; + bool getAtenViewOpSizes(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + SmallVector &dimSizes) const; }; template <> bool ConvertAtenViewOp::getAtenViewOpSizes( - AtenViewOp op, - OpAdaptor adaptor, - ConversionPatternRewriter& rewriter, - SmallVector& dimSizes) const { + AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + SmallVector &dimSizes) const { return getListConstructElements(adaptor.size(), dimSizes); } template <> bool ConvertAtenViewOp::getAtenViewOpSizes( - AtenReshapeOp op, - OpAdaptor adaptor, - ConversionPatternRewriter& rewriter, - SmallVector& dimSizes) const { + AtenReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + SmallVector &dimSizes) const { return getListConstructElements(adaptor.shape(), dimSizes); } @@ -415,10 +404,10 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); #undef INSERT_ATENOP_PATTERN -#define INSERT_VIEW_OP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ +#define INSERT_VIEW_OP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ patterns.add>(typeConverter, context); - INSERT_VIEW_OP_PATTERN(AtenViewOp); - INSERT_VIEW_OP_PATTERN(AtenReshapeOp); + INSERT_VIEW_OP_PATTERN(AtenViewOp); + INSERT_VIEW_OP_PATTERN(AtenReshapeOp); #undef INSERT_VIEW_OP_PATTERN } diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir index 4783e9940..1c878a5a5 100644 --- a/test/Conversion/TorchToMhlo/view_like.mlir +++ b/test/Conversion/TorchToMhlo/view_like.mlir @@ -43,9 +43,8 @@ // CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64> // CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> // CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor -// CHECK: %[[T30:.*]] = mhlo.convert %[[T29]] : tensor -// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor -> !torch.vtensor<[?,?,?],f32> -// CHECK: return %[[T31]] : !torch.vtensor<[?,?,?],f32> +// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[T30]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -96,10 +95,9 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32 // CHECK: %[[T26:.*]] = tensor.from_elements %[[T11]], %[[C0_I64_2]], %[[C0_I64_2]] : tensor<3xi64> // CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64> // CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor -// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor) -> tensor<2x65x256xf32> -// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32> -// CHECK: return %[[T31]] : !torch.vtensor<[2,65,256],f32> +// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32> +// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32> +// CHECK: return %[[T30]] : !torch.vtensor<[2,65,256],f32> func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -151,10 +149,9 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6 // CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64> // CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64> // CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor -// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor) -> tensor -// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor -> !torch.vtensor<[?,1,?],f32> -// CHECK: return %[[T31]] : !torch.vtensor<[?,1,?],f32> +// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor -> !torch.vtensor<[?,1,?],f32> +// CHECK: return %[[T30]] : !torch.vtensor<[?,1,?],f32> func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 @@ -206,10 +203,9 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64> // CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64> // CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32> -// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<4x?x256xf32>) -> tensor<4x1x256xf32> -// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32> -// CHECK: return %[[T31]] : !torch.vtensor<[4,1,256],f32> +// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32> +// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32> +// CHECK: return %[[T30]] : !torch.vtensor<[4,1,256],f32> func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 @@ -247,9 +243,8 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2 // CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64> // CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> // CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor -// CHECK: %[[T16:.*]] = mhlo.convert %[[T15]] : tensor -// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor -> !torch.vtensor<[?,?,?],f32> -// CHECK: return %[[T17]] : !torch.vtensor<[?,?,?],f32> +// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[T16]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 @@ -285,10 +280,9 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK: %[[T12:.*]] = tensor.from_elements %[[C0_I64_1]], %[[C0_I64]], %[[C0_I64_1]] : tensor<3xi64> // CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64> // CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32> -// CHECK: %[[T16:.*]] = mhlo.convert(%[[T15]]) : (tensor<4x?x256xf32>) -> tensor<4x33x256xf32> -// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32> -// CHECK: return %[[T17]] : !torch.vtensor<[4,33,256],f32> +// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32> +// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32> +// CHECK: return %[[T16]] : !torch.vtensor<[4,33,256],f32> func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { %int1 = torch.constant.int 1 %int2 = torch.constant.int 2