mirror of https://github.com/llvm/torch-mlir
[Stablehlo] Support AtenTrilOp (#3359)
1. lower aten.tril to stablehlo composed by iota, select and so forth 2. add related e2e test casespull/3368/head
parent
8814d0ae64
commit
cc28d566ff
|
@ -2052,6 +2052,77 @@ LogicalResult ConvertAtenOp<AtenBitwiseRightShiftTensorOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
||||||
|
AtenTrilOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
|
Value self = adaptor.getSelf();
|
||||||
|
|
||||||
|
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||||
|
if (!selfTy.hasStaticShape()) {
|
||||||
|
return op->emitError("dynamic shaped input is not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<int64_t> selfShape = selfTy.getShape();
|
||||||
|
int64_t selfRank = selfTy.getRank();
|
||||||
|
auto iotaElementTy = mlir::IntegerType::get(op.getContext(), 64);
|
||||||
|
auto iotaTy = RankedTensorType::get(
|
||||||
|
{selfShape[selfRank - 2], selfShape[selfRank - 1]}, iotaElementTy);
|
||||||
|
Value colIdxTensor =
|
||||||
|
rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 1).getResult();
|
||||||
|
Value rowIdxTensor =
|
||||||
|
rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 0).getResult();
|
||||||
|
|
||||||
|
Value diagonal = adaptor.getDiagonal();
|
||||||
|
Value diagonalTensor =
|
||||||
|
rewriter.create<tensor::FromElementsOp>(loc, diagonal).getResult();
|
||||||
|
|
||||||
|
auto bcastDimensions = rewriter.getDenseI64ArrayAttr({1});
|
||||||
|
Value shiftedRowIdxTensor = rewriter.create<chlo::BroadcastAddOp>(
|
||||||
|
loc, rowIdxTensor, diagonalTensor, bcastDimensions);
|
||||||
|
|
||||||
|
auto cmpDirectionAttr = stablehlo::ComparisonDirectionAttr::get(
|
||||||
|
rewriter.getContext(), stablehlo::ComparisonDirection::LE);
|
||||||
|
auto cmpTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
||||||
|
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
||||||
|
auto cmpTy = iotaTy.clone(rewriter.getI1Type());
|
||||||
|
Value cmpRes = rewriter.create<stablehlo::CompareOp>(
|
||||||
|
loc, cmpTy, colIdxTensor, shiftedRowIdxTensor, cmpDirectionAttr,
|
||||||
|
cmpTypeAttr);
|
||||||
|
|
||||||
|
auto resTy =
|
||||||
|
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||||
|
|
||||||
|
auto bcastTy = resTy.clone(rewriter.getI1Type());
|
||||||
|
auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1});
|
||||||
|
Value bcastedCmpRes = rewriter.create<stablehlo::BroadcastInDimOp>(
|
||||||
|
loc, bcastTy, cmpRes, bcastAttr);
|
||||||
|
|
||||||
|
auto resElemTy = resTy.getElementType();
|
||||||
|
Value zeroTensor;
|
||||||
|
if (resElemTy.isa<mlir::FloatType>()) {
|
||||||
|
auto constAttr = SplatElementsAttr::get(
|
||||||
|
resTy, llvm::APFloat::getZero(
|
||||||
|
resElemTy.cast<FloatType>().getFloatSemantics(), false));
|
||||||
|
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
|
||||||
|
} else if (resElemTy.isa<mlir::IntegerType>()) {
|
||||||
|
auto constAttr = SplatElementsAttr::get(
|
||||||
|
resTy,
|
||||||
|
llvm::APInt::getZero(resElemTy.cast<mlir::IntegerType>().getWidth()));
|
||||||
|
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
|
||||||
|
} else {
|
||||||
|
return op.emitError("element type is not float or integer");
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<stablehlo::SelectOp>(
|
||||||
|
op.getOperation(), resTy, bcastedCmpRes, self, zeroTensor);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||||||
|
@ -2218,6 +2289,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
||||||
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
|
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
|
INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);
|
INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);
|
||||||
|
|
||||||
|
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
||||||
|
|
|
@ -524,9 +524,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
||||||
"AtenSubFloatModule_basic",
|
"AtenSubFloatModule_basic",
|
||||||
"AtenTopKModule_basic",
|
"AtenTopKModule_basic",
|
||||||
"AtenTopKSmallestModule_basic",
|
"AtenTopKSmallestModule_basic",
|
||||||
"AtenTrilModule_basic",
|
|
||||||
"AtenTrilWithNegDiagonalModule_basic",
|
|
||||||
"AtenTrilWithPosDiagonalModule_basic",
|
|
||||||
"Aten_EmbeddingBagExample_basic",
|
"Aten_EmbeddingBagExample_basic",
|
||||||
"AvgPool2dDivisorOverrideModule_basic",
|
"AvgPool2dDivisorOverrideModule_basic",
|
||||||
"BernoulliTensorModule_basic",
|
"BernoulliTensorModule_basic",
|
||||||
|
@ -867,6 +864,9 @@ STABLEHLO_PASS_SET = {
|
||||||
"AtenRoundIntModule_basic",
|
"AtenRoundIntModule_basic",
|
||||||
"AtenSubFloatModule_basic",
|
"AtenSubFloatModule_basic",
|
||||||
"AtenToDeviceModule_basic",
|
"AtenToDeviceModule_basic",
|
||||||
|
"AtenTrilStaticModule_basic",
|
||||||
|
"AtenTrilWithNegDiagonalStaticModule_basic",
|
||||||
|
"AtenTrilWithPosDiagonalStaticModule_basic",
|
||||||
"Aten_CastFloatModule_basic",
|
"Aten_CastFloatModule_basic",
|
||||||
"Aten_CastLongModule_basic",
|
"Aten_CastLongModule_basic",
|
||||||
"AvgPool1dStaticModule_basic",
|
"AvgPool1dStaticModule_basic",
|
||||||
|
|
|
@ -5338,6 +5338,29 @@ def AtenTrilModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class AtenTrilStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([8, 8], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.tril(x)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenTrilStaticModule())
|
||||||
|
def AtenTrilStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(8, 8))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AtenTrilWithPosDiagonalModule(torch.nn.Module):
|
class AtenTrilWithPosDiagonalModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -5361,6 +5384,29 @@ def AtenTrilWithPosDiagonalModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class AtenTrilWithPosDiagonalStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([9, 4, 3], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.tril(x, diagonal=2)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenTrilWithPosDiagonalStaticModule())
|
||||||
|
def AtenTrilWithPosDiagonalStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(9, 4, 3))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AtenTrilWithNegDiagonalModule(torch.nn.Module):
|
class AtenTrilWithNegDiagonalModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -5384,6 +5430,29 @@ def AtenTrilWithNegDiagonalModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class AtenTrilWithNegDiagonalStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([3, 1, 5, 9], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.tril(x, diagonal=-4)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenTrilWithNegDiagonalStaticModule())
|
||||||
|
def AtenTrilWithNegDiagonalStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 1, 5, 9))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class AtenRoundFloatModule(torch.nn.Module):
|
class AtenRoundFloatModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -319,3 +319,25 @@ func.func @torch.aten.bitwise_right_shift.Tensor(%arg0: !torch.vtensor<[3,4],si6
|
||||||
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64>
|
%0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64>
|
||||||
return %0 : !torch.vtensor<[3,4],si64>
|
return %0 : !torch.vtensor<[3,4],si64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.tril(
|
||||||
|
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[2,3,5],f32>,
|
||||||
|
// CHECK-SAME: %[[ARG_1:.*]]: !torch.int) -> !torch.vtensor<[2,3,5],f32>
|
||||||
|
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[2,3,5],f32> -> tensor<2x3x5xf32>
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_i64 %[[ARG_1]]
|
||||||
|
// CHECK: %[[VAL_2:.*]] = stablehlo.iota dim = 1 : tensor<3x5xi64>
|
||||||
|
// CHECK: %[[VAL_3:.*]] = stablehlo.iota dim = 0 : tensor<3x5xi64>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xi64>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_3]], %[[VAL_4]] {broadcast_dimensions = array<i64: 1>} : (tensor<3x5xi64>, tensor<1xi64>) -> tensor<3x5xi64>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = stablehlo.compare LE, %[[VAL_2]], %[[VAL_5]], SIGNED : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = stablehlo.broadcast_in_dim %[[VAL_6]], dims = [1, 2] : (tensor<3x5xi1>) -> tensor<2x3x5xi1>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2x3x5xf32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = stablehlo.select %[[VAL_7]], %[[VAL_0]], %[[VAL_8]] : tensor<2x3x5xi1>, tensor<2x3x5xf32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x3x5xf32> -> !torch.vtensor<[2,3,5],f32>
|
||||||
|
// CHECK: return %[[VAL_10:.*]] : !torch.vtensor<[2,3,5],f32>
|
||||||
|
func.func @torch.aten.tril(%arg0: !torch.vtensor<[2,3,5],f32>, %arg1: !torch.int) -> !torch.vtensor<[2,3,5],f32> {
|
||||||
|
%0 = torch.aten.tril %arg0, %arg1:!torch.vtensor<[2,3,5],f32>, !torch.int -> !torch.vtensor<[2,3,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[2,3,5],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue