mirror of https://github.com/llvm/torch-mlir
[stablehlo] add aten left/right shift op conversion support (#3234)
parent cd33d8b011
commit 122eb69a98
@@ -1978,6 +1978,36 @@ LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
   return success();
 }

+// AtenBitwiseLeftShiftTensorOp
+template <>
+LogicalResult ConvertAtenOp<AtenBitwiseLeftShiftTensorOp>::matchAndRewrite(
+    AtenBitwiseLeftShiftTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value lhs = adaptor.getSelf();
+  Value rhs = adaptor.getOther();
+
+  auto resultType =
+      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
+  rewriter.replaceOpWithNewOp<stablehlo::ShiftLeftOp>(op, lhs, rhs);
+  return success();
+}
+
+// AtenBitwiseRightShiftTensorOp
+template <>
+LogicalResult ConvertAtenOp<AtenBitwiseRightShiftTensorOp>::matchAndRewrite(
+    AtenBitwiseRightShiftTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value lhs = adaptor.getSelf();
+  Value rhs = adaptor.getOther();
+
+  auto resultType =
+      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+  rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
+  rewriter.replaceOpWithNewOp<stablehlo::ShiftRightArithmeticOp>(op, lhs, rhs);
+  return success();
+}
+
 void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, const TorchToStablehloOptions &options) {
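Note that the pattern broadcasts only the shift amount (rhs) to the converted result type via hlo::promoteAndBroadcast; the values being shifted (lhs) are expected to already match it. At the PyTorch level, the behavior being reproduced is ordinary elementwise broadcasting. A minimal sketch of that contract in Python (plain torch, no torch-mlir needed; shapes mirror the lit test further down):

import torch

# A (3, 4) value tensor and a (3, 1) shift-amount tensor that
# broadcasts along the last dimension, as in the lit test below.
lhs = torch.arange(12, dtype=torch.int32).reshape(3, 4)
rhs = torch.tensor([[1], [2], [3]], dtype=torch.int32)

shifted = torch.bitwise_left_shift(lhs, rhs)  # elementwise lhs << rhs
assert shifted.shape == (3, 4)
assert torch.equal(shifted, lhs << rhs)  # the operator form is equivalent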
@@ -2137,6 +2167,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenFlipOp);
   INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
   INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
+  INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
+  INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);
 #undef INSERT_ATENOP_PATTERN

 #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
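Registering the two patterns also marks the Torch ops illegal for the conversion target, so the lowering fires whenever they appear in the input module. A hedged end-to-end sketch, assuming the legacy pt1-era torch_mlir.compile entry point with a StableHLO output type (the exact API differs across torch-mlir versions):

import torch
import torch_mlir  # assumption: the pt1-era torch_mlir.compile API

class ShiftModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.bitwise_left_shift(x, y)

x = torch.randint(0, 16, (3, 4), dtype=torch.int64)
y = torch.randint(0, 4, (3, 4), dtype=torch.int64)

# Lowers through the Torch dialect into StableHLO, exercising the
# ShiftLeftOp pattern registered above.
module = torch_mlir.compile(ShiftModule(), [x, y], output_type="stablehlo")
print(module)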
@@ -1009,6 +1009,10 @@ STABLEHLO_PASS_SET = {
     "ElementwiseBitwiseNotInt64Module_basic",
     "ElementwiseBitwiseOrStaticShapeModule_basic",
     "ElementwiseBitwiseXorStaticShapeModule_basic",
+    "ElementwiseBitwiseLeftShiftInt64Module_basic",
+    "ElementwiseBitwiseLeftShiftInt8Module_basic",
+    "ElementwiseBitwiseRightShiftInt64Module_basic",
+    "ElementwiseBitwiseRightShiftInt8Module_basic",
     "ElementwiseCeilModule_basic",
     "ElementwiseClampMaxModule_basic",
     "ElementwiseClampMinModule_basic",
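The four names added to STABLEHLO_PASS_SET are e2e test modules whose compiled output is compared against eager PyTorch. A rough sketch of what one of them looks like, following the suite's conventions (the decorators come from the e2e framework; this is a reconstruction for illustration, not necessarily the exact module in the suite):

import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

class ElementwiseBitwiseLeftShiftInt64Module(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int64, True),  # values to shift
        ([-1, -1], torch.int64, True),  # shift amounts
    ])
    def forward(self, lhs, rhs):
        return torch.bitwise_left_shift(lhs, rhs)

@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt64Module())
def ElementwiseBitwiseLeftShiftInt64Module_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=-100, high=100),
                   tu.randint(3, 4, low=0, high=8))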
@@ -315,3 +315,34 @@ func.func @torch.aten.uniform(%arg0: !torch.vtensor<[32, 64],f64>) -> !torch.vte
   %0 = torch.aten.uniform %arg0, %float0, %float1, %none : !torch.vtensor<[32, 64],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[32, 64],f64>
   return %0 : !torch.vtensor<[32, 64],f64>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor(
+// CHECK-SAME:    %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>,
+// CHECK-SAME:    %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {
+// CHECK:         %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32>
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[3,1],si32> -> tensor<3x1xi32>
+// CHECK:         %[[VAL_2:.*]] = stablehlo.broadcast_in_dim %[[VAL_1]], dims = [0, 1] : (tensor<3x1xi32>) -> tensor<3x4xi32>
+// CHECK:         %[[VAL_3:.*]] = stablehlo.shift_left %[[VAL_0]], %[[VAL_2]] : tensor<3x4xi32>
+// CHECK:         %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x4xi32> -> !torch.vtensor<[3,4],si32>
+// CHECK:         return %[[VAL_4]] : !torch.vtensor<[3,4],si32>
+func.func @torch.aten.bitwise_left_shift.Tensor(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> {
+  %0 = torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,1],si32> -> !torch.vtensor<[3,4],si32>
+  return %0 : !torch.vtensor<[3,4],si32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.bitwise_right_shift.Tensor(
+// CHECK-SAME:    %[[ARG_0:.*]]: !torch.vtensor<[3,4],si64>,
+// CHECK-SAME:    %[[ARG_1:.*]]: !torch.vtensor<[3,4],si64>) -> !torch.vtensor<[3,4],si64> {
+// CHECK:         %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64>
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64>
+// CHECK:         %[[VAL_2:.*]] = stablehlo.shift_right_arithmetic %[[VAL_0]], %[[VAL_1]] : tensor<3x4xi64>
+// CHECK:         %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64>
+// CHECK:         return %[[VAL_3]] : !torch.vtensor<[3,4],si64>
+func.func @torch.aten.bitwise_right_shift.Tensor(%arg0: !torch.vtensor<[3,4],si64>, %arg1: !torch.vtensor<[3,4],si64>) -> !torch.vtensor<[3,4],si64> {
+  %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64>
+  return %0 : !torch.vtensor<[3,4],si64>
+}
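One detail the si64 test pins down: the lowering picks stablehlo.shift_right_arithmetic (sign-extending) rather than stablehlo.shift_right_logical (zero-filling), which matches PyTorch's right-shift semantics for signed integer tensors. A quick Python check of that behavior:

import torch

x = torch.tensor([-8, -1, 16], dtype=torch.int64)
s = torch.tensor([1, 1, 2], dtype=torch.int64)

# Arithmetic shift preserves the sign bit: -8 >> 1 == -4 and
# -1 >> 1 == -1; a logical shift would produce large positives.
out = torch.bitwise_right_shift(x, s)
assert out.tolist() == [-4, -1, 4]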