[Stablehlo] Support AtenTrilOp (#3359)

1. Lower aten.tril to StableHLO by composing iota, compare, broadcast, and select ops (a minimal sketch of the idea follows the commit summary below).
2. Add the related e2e test cases.
pull/3368/head
Wu Yuan 2024-05-20 15:49:24 +08:00 committed by GitHub
parent 8814d0ae64
commit cc28d566ff
4 changed files with 167 additions and 3 deletions
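For readers skimming the diff, here is a minimal PyTorch sketch of the decomposition idea, using a hypothetical helper tril_via_mask of our own; the actual pattern (first file below) emits stablehlo.iota, chlo.broadcast_add, stablehlo.compare, stablehlo.broadcast_in_dim, and stablehlo.select rather than eager torch ops.

import torch

def tril_via_mask(x: torch.Tensor, diagonal: int = 0) -> torch.Tensor:
    # Hypothetical helper, for illustration only.
    rows, cols = x.shape[-2], x.shape[-1]
    # Index grids over the trailing 2-D slice (stand-ins for iota dim=0/dim=1).
    row_idx = torch.arange(rows).unsqueeze(1).expand(rows, cols)
    col_idx = torch.arange(cols).unsqueeze(0).expand(rows, cols)
    # Keep element (i, j) when j <= i + diagonal (torch.tril semantics).
    mask = col_idx <= row_idx + diagonal
    # The 2-D mask broadcasts over any leading batch dims; select against zeros.
    return torch.where(mask, x, torch.zeros_like(x))

x = torch.rand(3, 1, 5, 9)
assert torch.equal(tril_via_mask(x, diagonal=-4), torch.tril(x, diagonal=-4))

The j <= i + diagonal comparison is what the CompareOp in the pattern encodes, and broadcasting the 2-D mask to the full result rank mirrors the BroadcastInDimOp step.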


@@ -2052,6 +2052,77 @@ LogicalResult ConvertAtenOp<AtenBitwiseRightShiftTensorOp>::matchAndRewrite(
  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
    AtenTrilOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();

  Value self = adaptor.getSelf();
  auto selfTy = self.getType().cast<RankedTensorType>();
  if (!selfTy.hasStaticShape()) {
    return op->emitError("dynamic shaped input is not supported");
  }
  ArrayRef<int64_t> selfShape = selfTy.getShape();
  int64_t selfRank = selfTy.getRank();

  auto iotaElementTy = mlir::IntegerType::get(op.getContext(), 64);
  auto iotaTy = RankedTensorType::get(
      {selfShape[selfRank - 2], selfShape[selfRank - 1]}, iotaElementTy);
  Value colIdxTensor =
      rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 1).getResult();
  Value rowIdxTensor =
      rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 0).getResult();

  Value diagonal = adaptor.getDiagonal();
  Value diagonalTensor =
      rewriter.create<tensor::FromElementsOp>(loc, diagonal).getResult();
  auto bcastDimensions = rewriter.getDenseI64ArrayAttr({1});
  Value shiftedRowIdxTensor = rewriter.create<chlo::BroadcastAddOp>(
      loc, rowIdxTensor, diagonalTensor, bcastDimensions);

  auto cmpDirectionAttr = stablehlo::ComparisonDirectionAttr::get(
      rewriter.getContext(), stablehlo::ComparisonDirection::LE);
  auto cmpTypeAttr = stablehlo::ComparisonTypeAttr::get(
      rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
  auto cmpTy = iotaTy.clone(rewriter.getI1Type());
  Value cmpRes = rewriter.create<stablehlo::CompareOp>(
      loc, cmpTy, colIdxTensor, shiftedRowIdxTensor, cmpDirectionAttr,
      cmpTypeAttr);

  auto resTy =
      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
  auto bcastTy = resTy.clone(rewriter.getI1Type());
  auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1});
  Value bcastedCmpRes = rewriter.create<stablehlo::BroadcastInDimOp>(
      loc, bcastTy, cmpRes, bcastAttr);

  auto resElemTy = resTy.getElementType();
  Value zeroTensor;
  if (resElemTy.isa<mlir::FloatType>()) {
    auto constAttr = SplatElementsAttr::get(
        resTy, llvm::APFloat::getZero(
                   resElemTy.cast<FloatType>().getFloatSemantics(), false));
    zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
  } else if (resElemTy.isa<mlir::IntegerType>()) {
    auto constAttr = SplatElementsAttr::get(
        resTy,
        llvm::APInt::getZero(resElemTy.cast<mlir::IntegerType>().getWidth()));
    zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
  } else {
    return op.emitError("element type is not float or integer");
  }

  rewriter.replaceOpWithNewOp<stablehlo::SelectOp>(
      op.getOperation(), resTy, bcastedCmpRes, self, zeroTensor);
  return success();
}

void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const TorchToStablehloOptions &options) {
@@ -2218,6 +2289,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
  INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
  INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
  INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);
  INSERT_ATENOP_PATTERN(AtenTrilOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \


@@ -524,9 +524,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"AtenSubFloatModule_basic", "AtenSubFloatModule_basic",
"AtenTopKModule_basic", "AtenTopKModule_basic",
"AtenTopKSmallestModule_basic", "AtenTopKSmallestModule_basic",
"AtenTrilModule_basic",
"AtenTrilWithNegDiagonalModule_basic",
"AtenTrilWithPosDiagonalModule_basic",
"Aten_EmbeddingBagExample_basic", "Aten_EmbeddingBagExample_basic",
"AvgPool2dDivisorOverrideModule_basic", "AvgPool2dDivisorOverrideModule_basic",
"BernoulliTensorModule_basic", "BernoulliTensorModule_basic",
@@ -867,6 +864,9 @@ STABLEHLO_PASS_SET = {
"AtenRoundIntModule_basic", "AtenRoundIntModule_basic",
"AtenSubFloatModule_basic", "AtenSubFloatModule_basic",
"AtenToDeviceModule_basic", "AtenToDeviceModule_basic",
"AtenTrilStaticModule_basic",
"AtenTrilWithNegDiagonalStaticModule_basic",
"AtenTrilWithPosDiagonalStaticModule_basic",
"Aten_CastFloatModule_basic", "Aten_CastFloatModule_basic",
"Aten_CastLongModule_basic", "Aten_CastLongModule_basic",
"AvgPool1dStaticModule_basic", "AvgPool1dStaticModule_basic",


@@ -5338,6 +5338,29 @@ def AtenTrilModule_basic(module, tu: TestUtils):
# ==============================================================================


class AtenTrilStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([8, 8], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.tril(x)


@register_test_case(module_factory=lambda: AtenTrilStaticModule())
def AtenTrilStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(8, 8))


# ==============================================================================
class AtenTrilWithPosDiagonalModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -5361,6 +5384,29 @@ def AtenTrilWithPosDiagonalModule_basic(module, tu: TestUtils):
# ==============================================================================


class AtenTrilWithPosDiagonalStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([9, 4, 3], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.tril(x, diagonal=2)


@register_test_case(module_factory=lambda: AtenTrilWithPosDiagonalStaticModule())
def AtenTrilWithPosDiagonalStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(9, 4, 3))


# ==============================================================================
class AtenTrilWithNegDiagonalModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -5384,6 +5430,29 @@ def AtenTrilWithNegDiagonalModule_basic(module, tu: TestUtils):
# ==============================================================================


class AtenTrilWithNegDiagonalStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([3, 1, 5, 9], torch.float32, True),
        ]
    )
    def forward(self, x):
        return torch.ops.aten.tril(x, diagonal=-4)


@register_test_case(module_factory=lambda: AtenTrilWithNegDiagonalStaticModule())
def AtenTrilWithNegDiagonalStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 1, 5, 9))


# ==============================================================================
class AtenRoundFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()


@@ -319,3 +319,25 @@ func.func @torch.aten.bitwise_right_shift.Tensor(%arg0: !torch.vtensor<[3,4],si6
  %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64>
  return %0 : !torch.vtensor<[3,4],si64>
}

// -----
// CHECK-LABEL: func.func @torch.aten.tril(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[2,3,5],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.int) -> !torch.vtensor<[2,3,5],f32>
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[2,3,5],f32> -> tensor<2x3x5xf32>
// CHECK: %[[VAL_1:.*]] = torch_c.to_i64 %[[ARG_1]]
// CHECK: %[[VAL_2:.*]] = stablehlo.iota dim = 1 : tensor<3x5xi64>
// CHECK: %[[VAL_3:.*]] = stablehlo.iota dim = 0 : tensor<3x5xi64>
// CHECK: %[[VAL_4:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_3]], %[[VAL_4]] {broadcast_dimensions = array<i64: 1>} : (tensor<3x5xi64>, tensor<1xi64>) -> tensor<3x5xi64>
// CHECK: %[[VAL_6:.*]] = stablehlo.compare LE, %[[VAL_2]], %[[VAL_5]], SIGNED : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1>
// CHECK: %[[VAL_7:.*]] = stablehlo.broadcast_in_dim %[[VAL_6]], dims = [1, 2] : (tensor<3x5xi1>) -> tensor<2x3x5xi1>
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2x3x5xf32>
// CHECK: %[[VAL_9:.*]] = stablehlo.select %[[VAL_7]], %[[VAL_0]], %[[VAL_8]] : tensor<2x3x5xi1>, tensor<2x3x5xf32>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x3x5xf32> -> !torch.vtensor<[2,3,5],f32>
// CHECK: return %[[VAL_10]] : !torch.vtensor<[2,3,5],f32>
func.func @torch.aten.tril(%arg0: !torch.vtensor<[2,3,5],f32>, %arg1: !torch.int) -> !torch.vtensor<[2,3,5],f32> {
  %0 = torch.aten.tril %arg0, %arg1 : !torch.vtensor<[2,3,5],f32>, !torch.int -> !torch.vtensor<[2,3,5],f32>
  return %0 : !torch.vtensor<[2,3,5],f32>
}