mirror of https://github.com/llvm/torch-mlir
[MHLO] Eliminate explicit dynamic output shape generating in converting AtenSliceTensorOp (#1245)
[MHLO] Eliminate explicit dynamic output shape generating in converting AtenSliceTensorOp
pull/1250/head
parent
0d1aa43764
commit
7bd173a1c4
|
@ -53,8 +53,8 @@ Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
|
|||
}
|
||||
|
||||
Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
|
||||
Value input, Value startIndex, Value endIndex,
|
||||
Value step, size_t dimIndex,
|
||||
Type outTy, Value input, Value startIndex,
|
||||
Value endIndex, Value step, size_t dimIndex,
|
||||
ArrayRef<Value> dimSizes) {
|
||||
auto loc = op->getLoc();
|
||||
// startIndex & endIndex have been normalized into range [0, dSize]
|
||||
|
@ -98,20 +98,15 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
|
|||
auto stridesTensor =
|
||||
rewriter.create<tensor::FromElementsOp>(loc, strides).getResult();
|
||||
|
||||
auto inputShape = inputTy.getShape();
|
||||
SmallVector<int64_t, 4> sliceShape(inputShape.begin(), inputShape.end());
|
||||
sliceShape[dimIndex] = ShapedType::kDynamicSize;
|
||||
auto sliceoutputTy =
|
||||
RankedTensorType::get(sliceShape, inputTy.getElementType());
|
||||
return rewriter.create<mhlo::RealDynamicSliceOp>(
|
||||
loc, sliceoutputTy, input, startTensor, endTensor, stridesTensor);
|
||||
loc, outTy, input, startTensor, endTensor, stridesTensor);
|
||||
}
|
||||
|
||||
// Get a dynamic slice of the tensor from startIndex to endIndex with stride
|
||||
// step on the specified dimension. The input startIndex(default to 0),
|
||||
// endIndex(default to dimSize), and step(default to 1) can be optional.
|
||||
FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
|
||||
Value input,
|
||||
Type outTy, Value input,
|
||||
llvm::Optional<Value> startIndexOpt,
|
||||
llvm::Optional<Value> endIndexOpt,
|
||||
llvm::Optional<Value> stepOpt, int64_t dim) {
|
||||
|
@ -152,7 +147,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
|
|||
op, "failed to get dimension sizes of the input");
|
||||
|
||||
auto dimSizes = *dimSizesInfo;
|
||||
return getDynamicSliceInternal(rewriter, op, input, normStartIndex,
|
||||
return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex,
|
||||
normEndIndex, step, dim, dimSizes);
|
||||
}
|
||||
|
||||
|
@ -174,6 +169,8 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
|||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
if (!selfTy)
|
||||
return op.emitError("only ranked tensor types are supported");
|
||||
auto outTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -192,14 +189,12 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
|||
llvm::Optional<Value> step = getOptionalVal(adaptor.step());
|
||||
|
||||
FailureOr<Value> sliceInfo =
|
||||
getDynamicSlice(rewriter, op, self, start, end, step, dim);
|
||||
getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim);
|
||||
if (failed(sliceInfo))
|
||||
return op.emitError("can not create a dynmaic slice");
|
||||
|
||||
auto slice = *sliceInfo;
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), slice);
|
||||
|
||||
rewriter.replaceOp(op, slice);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -207,14 +202,13 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
|||
// specialized.
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
AtenOpT op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto rankType =
|
||||
adaptor.self().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!rankType)
|
||||
|
@ -236,18 +230,19 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
|
|||
return success();
|
||||
}
|
||||
|
||||
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
|
||||
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
|
||||
dSize = rewriter.create<ToI64Op>(loc, dSize).getResult();
|
||||
return dSize;
|
||||
});
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
|
||||
// The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
|
||||
// One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
|
||||
// the range of i32(4GiB)
|
||||
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
|
||||
// The i64 calculation is much slower than i32 on some devices, such as
|
||||
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
|
||||
// unlikely to exceed the range of i32(4GiB)
|
||||
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
|
||||
// dimSize: cast i64 -> i32
|
||||
dSize = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
|
||||
dSize =
|
||||
rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
|
||||
return dSize;
|
||||
});
|
||||
#endif
|
||||
|
@ -272,28 +267,22 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
|
|||
return success();
|
||||
}
|
||||
|
||||
bool getAtenViewOpSizes(
|
||||
AtenOpT op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
SmallVector<Value, 4>& dimSizes) const;
|
||||
bool getAtenViewOpSizes(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
SmallVector<Value, 4> &dimSizes) const;
|
||||
};
|
||||
|
||||
template <>
|
||||
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
|
||||
AtenViewOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
SmallVector<Value, 4>& dimSizes) const {
|
||||
AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
|
||||
SmallVector<Value, 4> &dimSizes) const {
|
||||
return getListConstructElements(adaptor.size(), dimSizes);
|
||||
}
|
||||
|
||||
template <>
|
||||
bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
|
||||
AtenReshapeOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
SmallVector<Value, 4>& dimSizes) const {
|
||||
AtenReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
|
||||
SmallVector<Value, 4> &dimSizes) const {
|
||||
return getListConstructElements(adaptor.shape(), dimSizes);
|
||||
}
|
||||
|
||||
|
@ -415,10 +404,10 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_VIEW_OP_PATTERN(AtenViewOp);
|
||||
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
|
||||
INSERT_VIEW_OP_PATTERN(AtenViewOp);
|
||||
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
|
||||
#undef INSERT_VIEW_OP_PATTERN
|
||||
}
|
||||
|
|
|
@ -43,9 +43,8 @@
|
|||
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64>
|
||||
// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[T30:.*]] = mhlo.convert %[[T29]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[T31]] : !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[T30]] : !torch.vtensor<[?,?,?],f32>
|
||||
func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int2 = torch.constant.int 2
|
||||
|
@ -96,10 +95,9 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32
|
|||
// CHECK: %[[T26:.*]] = tensor.from_elements %[[T11]], %[[C0_I64_2]], %[[C0_I64_2]] : tensor<3xi64>
|
||||
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T25]], %[[T21]], %[[T23]] : tensor<3xi64>
|
||||
// CHECK: %[[T28:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x65x256xf32>
|
||||
// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<?x65x256xf32>) -> tensor<2x65x256xf32>
|
||||
// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
|
||||
// CHECK: return %[[T31]] : !torch.vtensor<[2,65,256],f32>
|
||||
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
|
||||
// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
|
||||
// CHECK: return %[[T30]] : !torch.vtensor<[2,65,256],f32>
|
||||
func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int2 = torch.constant.int 2
|
||||
|
@ -151,10 +149,9 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6
|
|||
// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64>
|
||||
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64>
|
||||
// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<?x?x?xf32>) -> tensor<?x1x?xf32>
|
||||
// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
|
||||
// CHECK: return %[[T31]] : !torch.vtensor<[?,1,?],f32>
|
||||
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
|
||||
// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
|
||||
// CHECK: return %[[T30]] : !torch.vtensor<[?,1,?],f32>
|
||||
func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
|
@ -206,10 +203,9 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
|
|||
// CHECK: %[[T26:.*]] = tensor.from_elements %[[C0_I64_2]], %[[T11]], %[[C0_I64_2]] : tensor<3xi64>
|
||||
// CHECK: %[[T27:.*]] = tensor.from_elements %[[T19]], %[[T25]], %[[T23]] : tensor<3xi64>
|
||||
// CHECK: %[[T28:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32>
|
||||
// CHECK: %[[T30:.*]] = mhlo.convert(%[[T29]]) : (tensor<4x?x256xf32>) -> tensor<4x1x256xf32>
|
||||
// CHECK: %[[T31:.*]] = torch_c.from_builtin_tensor %[[T30]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
|
||||
// CHECK: return %[[T31]] : !torch.vtensor<[4,1,256],f32>
|
||||
// CHECK: %[[T29:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T26]], %[[T27]], %[[T28]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
|
||||
// CHECK: %[[T30:.*]] = torch_c.from_builtin_tensor %[[T29]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
|
||||
// CHECK: return %[[T30]] : !torch.vtensor<[4,1,256],f32>
|
||||
func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
|
@ -247,9 +243,8 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2
|
|||
// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64>
|
||||
// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[T16:.*]] = mhlo.convert %[[T15]] : tensor<?x?x?xf32>
|
||||
// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[T17]] : !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[T16]] : !torch.vtensor<[?,?,?],f32>
|
||||
func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
|
@ -285,10 +280,9 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
|
|||
// CHECK: %[[T12:.*]] = tensor.from_elements %[[C0_I64_1]], %[[C0_I64]], %[[C0_I64_1]] : tensor<3xi64>
|
||||
// CHECK: %[[T13:.*]] = tensor.from_elements %[[T5]], %[[T11]], %[[T9]] : tensor<3xi64>
|
||||
// CHECK: %[[T14:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
|
||||
// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x?x256xf32>
|
||||
// CHECK: %[[T16:.*]] = mhlo.convert(%[[T15]]) : (tensor<4x?x256xf32>) -> tensor<4x33x256xf32>
|
||||
// CHECK: %[[T17:.*]] = torch_c.from_builtin_tensor %[[T16]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
|
||||
// CHECK: return %[[T17]] : !torch.vtensor<[4,33,256],f32>
|
||||
// CHECK: %[[T15:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[T12]], %[[T13]], %[[T14]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
|
||||
// CHECK: %[[T16:.*]] = torch_c.from_builtin_tensor %[[T15]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
|
||||
// CHECK: return %[[T16]] : !torch.vtensor<[4,33,256],f32>
|
||||
func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%int2 = torch.constant.int 2
|
||||
|
|
Loading…
Reference in New Issue