[stablehlo] add aten left/right shift op conversion support (#3234)

pull/3238/head
penguin_wwy 2024-04-26 09:20:49 +08:00 committed by GitHub
parent cd33d8b011
commit 122eb69a98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 0 deletions

View File

@ -1978,6 +1978,36 @@ LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
return success();
}
// AtenBitwiseLeftShiftTensorOp
template <>
LogicalResult ConvertAtenOp<AtenBitwiseLeftShiftTensorOp>::matchAndRewrite(
AtenBitwiseLeftShiftTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.getSelf();
Value rhs = adaptor.getOther();
auto resultType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
rewriter.replaceOpWithNewOp<stablehlo::ShiftLeftOp>(op, lhs, rhs);
return success();
}
// AtenBitwiseRightShiftTensorOp
template <>
LogicalResult ConvertAtenOp<AtenBitwiseRightShiftTensorOp>::matchAndRewrite(
AtenBitwiseRightShiftTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value lhs = adaptor.getSelf();
Value rhs = adaptor.getOther();
auto resultType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
rewriter.replaceOpWithNewOp<stablehlo::ShiftRightArithmeticOp>(op, lhs, rhs);
return success();
}
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
@ -2137,6 +2167,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenFlipOp);
INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \

View File

@ -1009,6 +1009,10 @@ STABLEHLO_PASS_SET = {
"ElementwiseBitwiseNotInt64Module_basic",
"ElementwiseBitwiseOrStaticShapeModule_basic",
"ElementwiseBitwiseXorStaticShapeModule_basic",
"ElementwiseBitwiseLeftShiftInt64Module_basic",
"ElementwiseBitwiseLeftShiftInt8Module_basic",
"ElementwiseBitwiseRightShiftInt64Module_basic",
"ElementwiseBitwiseRightShiftInt8Module_basic",
"ElementwiseCeilModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",

View File

@ -315,3 +315,34 @@ func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vte
%0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
return %0 : !torch.vtensor<[32, 64],f64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0:.*]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1:.*]] : !torch.vtensor<[3,1],si32> -> tensor<3x1xi32>
// CHECK: %[[VAL_2:.*]] = stablehlo.broadcast_in_dim %[[VAL_1:.*]], dims = [0, 1] : (tensor<3x1xi32>) -> tensor<3x4xi32>
// CHECK: %[[VAL_3:.*]] = stablehlo.shift_left %[[VAL_0:.*]], %[[VAL_2:.*]] : tensor<3x4xi32>
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3:.*]] : tensor<3x4xi32> -> !torch.vtensor<[3,4],si32>
// CHECK: return %[[VAL_4:.*]] : !torch.vtensor<[3,4],si32>
func.func @torch.aten.bitwise_left_shift.Tensor(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {
%0 = torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,1],si32> -> !torch.vtensor<[3,4],si32>
return %0 : !torch.vtensor<[3,4],si32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.bitwise_right_shift.Tensor(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si64>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,4],si64>) -> !torch.vtensor<[3,4],si64> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0:.*]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1:.*]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64>
// CHECK: %[[VAL_2:.*]] = stablehlo.shift_right_arithmetic %[[VAL_0:.*]], %[[VAL_1:.*]] : tensor<3x4xi64>
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2:.*]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64>
// CHECK: return %[[VAL_3:.*]] : !torch.vtensor<[3,4],si64>
func.func @torch.aten.bitwise_right_shift.Tensor(%arg0: !torch.vtensor<[3,4],si64>, %arg1: !torch.vtensor<[3,4],si64>) -> !torch.vtensor<[3,4],si64> {
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64>
return %0 : !torch.vtensor<[3,4],si64>
}