mirror of https://github.com/llvm/torch-mlir
[Stablehlo] Support AtenTrilOp (#3359)
1. lower aten.tril to stablehlo composed by iota, select and so forth 2. add related e2e test casespull/3368/head
parent
8814d0ae64
commit
cc28d566ff
|
@ -2052,6 +2052,77 @@ LogicalResult ConvertAtenOp<AtenBitwiseRightShiftTensorOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
||||
AtenTrilOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Value self = adaptor.getSelf();
|
||||
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
if (!selfTy.hasStaticShape()) {
|
||||
return op->emitError("dynamic shaped input is not supported");
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> selfShape = selfTy.getShape();
|
||||
int64_t selfRank = selfTy.getRank();
|
||||
auto iotaElementTy = mlir::IntegerType::get(op.getContext(), 64);
|
||||
auto iotaTy = RankedTensorType::get(
|
||||
{selfShape[selfRank - 2], selfShape[selfRank - 1]}, iotaElementTy);
|
||||
Value colIdxTensor =
|
||||
rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 1).getResult();
|
||||
Value rowIdxTensor =
|
||||
rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 0).getResult();
|
||||
|
||||
Value diagonal = adaptor.getDiagonal();
|
||||
Value diagonalTensor =
|
||||
rewriter.create<tensor::FromElementsOp>(loc, diagonal).getResult();
|
||||
|
||||
auto bcastDimensions = rewriter.getDenseI64ArrayAttr({1});
|
||||
Value shiftedRowIdxTensor = rewriter.create<chlo::BroadcastAddOp>(
|
||||
loc, rowIdxTensor, diagonalTensor, bcastDimensions);
|
||||
|
||||
auto cmpDirectionAttr = stablehlo::ComparisonDirectionAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonDirection::LE);
|
||||
auto cmpTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
||||
auto cmpTy = iotaTy.clone(rewriter.getI1Type());
|
||||
Value cmpRes = rewriter.create<stablehlo::CompareOp>(
|
||||
loc, cmpTy, colIdxTensor, shiftedRowIdxTensor, cmpDirectionAttr,
|
||||
cmpTypeAttr);
|
||||
|
||||
auto resTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
|
||||
auto bcastTy = resTy.clone(rewriter.getI1Type());
|
||||
auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1});
|
||||
Value bcastedCmpRes = rewriter.create<stablehlo::BroadcastInDimOp>(
|
||||
loc, bcastTy, cmpRes, bcastAttr);
|
||||
|
||||
auto resElemTy = resTy.getElementType();
|
||||
Value zeroTensor;
|
||||
if (resElemTy.isa<mlir::FloatType>()) {
|
||||
auto constAttr = SplatElementsAttr::get(
|
||||
resTy, llvm::APFloat::getZero(
|
||||
resElemTy.cast<FloatType>().getFloatSemantics(), false));
|
||||
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
|
||||
} else if (resElemTy.isa<mlir::IntegerType>()) {
|
||||
auto constAttr = SplatElementsAttr::get(
|
||||
resTy,
|
||||
llvm::APInt::getZero(resElemTy.cast<mlir::IntegerType>().getWidth()));
|
||||
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
|
||||
} else {
|
||||
return op.emitError("element type is not float or integer");
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<stablehlo::SelectOp>(
|
||||
op.getOperation(), resTy, bcastedCmpRes, self, zeroTensor);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||
|
@ -2218,6 +2289,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);
|
||||
|
||||
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
||||
|
|
|
@ -524,9 +524,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"AtenSubFloatModule_basic",
|
||||
"AtenTopKModule_basic",
|
||||
"AtenTopKSmallestModule_basic",
|
||||
"AtenTrilModule_basic",
|
||||
"AtenTrilWithNegDiagonalModule_basic",
|
||||
"AtenTrilWithPosDiagonalModule_basic",
|
||||
"Aten_EmbeddingBagExample_basic",
|
||||
"AvgPool2dDivisorOverrideModule_basic",
|
||||
"BernoulliTensorModule_basic",
|
||||
|
@ -867,6 +864,9 @@ STABLEHLO_PASS_SET = {
|
|||
"AtenRoundIntModule_basic",
|
||||
"AtenSubFloatModule_basic",
|
||||
"AtenToDeviceModule_basic",
|
||||
"AtenTrilStaticModule_basic",
|
||||
"AtenTrilWithNegDiagonalStaticModule_basic",
|
||||
"AtenTrilWithPosDiagonalStaticModule_basic",
|
||||
"Aten_CastFloatModule_basic",
|
||||
"Aten_CastLongModule_basic",
|
||||
"AvgPool1dStaticModule_basic",
|
||||
|
|
|
@ -5338,6 +5338,29 @@ def AtenTrilModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenTrilStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([8, 8], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.tril(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenTrilStaticModule())
|
||||
def AtenTrilStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(8, 8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenTrilWithPosDiagonalModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -5361,6 +5384,29 @@ def AtenTrilWithPosDiagonalModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenTrilWithPosDiagonalStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([9, 4, 3], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.tril(x, diagonal=2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenTrilWithPosDiagonalStaticModule())
|
||||
def AtenTrilWithPosDiagonalStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(9, 4, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenTrilWithNegDiagonalModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -5384,6 +5430,29 @@ def AtenTrilWithNegDiagonalModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenTrilWithNegDiagonalStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 1, 5, 9], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.tril(x, diagonal=-4)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenTrilWithNegDiagonalStaticModule())
|
||||
def AtenTrilWithNegDiagonalStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 5, 9))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenRoundFloatModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -319,3 +319,25 @@ func.func @torch.aten.bitwise_right_shift.Tensor(%arg0: !torch.vtensor<[3,4],si6
|
|||
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64>
|
||||
return %0 : !torch.vtensor<[3,4],si64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.tril(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[2,3,5],f32>,
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: !torch.int) -> !torch.vtensor<[2,3,5],f32>
|
||||
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[2,3,5],f32> -> tensor<2x3x5xf32>
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_i64 %[[ARG_1]]
|
||||
// CHECK: %[[VAL_2:.*]] = stablehlo.iota dim = 1 : tensor<3x5xi64>
|
||||
// CHECK: %[[VAL_3:.*]] = stablehlo.iota dim = 0 : tensor<3x5xi64>
|
||||
// CHECK: %[[VAL_4:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xi64>
|
||||
// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_3]], %[[VAL_4]] {broadcast_dimensions = array<i64: 1>} : (tensor<3x5xi64>, tensor<1xi64>) -> tensor<3x5xi64>
|
||||
// CHECK: %[[VAL_6:.*]] = stablehlo.compare LE, %[[VAL_2]], %[[VAL_5]], SIGNED : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1>
|
||||
// CHECK: %[[VAL_7:.*]] = stablehlo.broadcast_in_dim %[[VAL_6]], dims = [1, 2] : (tensor<3x5xi1>) -> tensor<2x3x5xi1>
|
||||
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2x3x5xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = stablehlo.select %[[VAL_7]], %[[VAL_0]], %[[VAL_8]] : tensor<2x3x5xi1>, tensor<2x3x5xf32>
|
||||
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x3x5xf32> -> !torch.vtensor<[2,3,5],f32>
|
||||
// CHECK: return %[[VAL_10:.*]] : !torch.vtensor<[2,3,5],f32>
|
||||
func.func @torch.aten.tril(%arg0: !torch.vtensor<[2,3,5],f32>, %arg1: !torch.int) -> !torch.vtensor<[2,3,5],f32> {
|
||||
%0 = torch.aten.tril %arg0, %arg1:!torch.vtensor<[2,3,5],f32>, !torch.int -> !torch.vtensor<[2,3,5],f32>
|
||||
return %0 : !torch.vtensor<[2,3,5],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue