[Torch] support aten.trunc (#3219)

Decompose `trunc(x)` into `sign(x) * floor(abs(x))`.
Yuanqiang Liu 2024-04-24 14:32:33 +08:00 committed by GitHub
parent e18bf42d0e
commit fab2696489
10 changed files with 162 additions and 0 deletions
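
As a quick sanity check of the identity behind this change, here is a minimal eager-mode PyTorch sketch (illustrative only, not part of the diff), using the same edge values as the new ElementwiseTruncModule test added below:

import torch

x = torch.tensor([-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5])
decomposed = torch.sign(x) * torch.floor(torch.abs(x))
# trunc rounds toward zero; +-inf survives because sign(+-inf) = +-1 and
# floor(abs(+-inf)) = inf, while nan propagates through the multiply.
assert torch.allclose(torch.trunc(x), decomposed, equal_nan=True)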


@@ -4223,6 +4223,52 @@ def Torch_AtenRound_Op : Torch_Op<"aten.round_", [
  }];
}
def Torch_AtenTruncOp : Torch_Op<"aten.trunc", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::trunc : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenTruncOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenTruncOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
  let hasFolder = 1;
}
def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::trunc_ : (Tensor) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self
  );
  let results = (outs
    AnyTorchOptionalNonValueTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenTrunc_Op::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenTrunc_Op::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}
def Torch_AtenSignOp : Torch_Op<"aten.sign", [
    AllowsTypeRefinement,
    HasValueSemantics,
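
For reference, `Torch_AtenTrunc_Op` models PyTorch's in-place `Tensor.trunc_`; a minimal eager-mode illustration (not part of the diff):

import torch

t = torch.tensor([-2.3, 1.5])
t.trunc_()  # in-place counterpart of torch.trunc, modeled by torch.aten.trunc_
assert torch.equal(t, torch.tensor([-2.0, 1.0]))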


@@ -1834,6 +1834,19 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
  return {};
}
//===----------------------------------------------------------------------===//
// AtenTruncOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
  auto resultType = getType().dyn_cast<ValueTensorType>();
  if (resultType && resultType.hasDtype() &&
      resultType.getDtype().isa<mlir::IntegerType>()) {
    return getSelf();
  }
  return {};
}
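
The folder above turns trunc on an integer-dtype tensor into a no-op, matching eager PyTorch; a small illustrative check (mirroring the new ElementwiseTruncIntModule test):

import torch

x = torch.randint(-100, 100, (3, 4), dtype=torch.int32)
# Truncation cannot change an integral value, so the op folds to its operand.
assert torch.equal(torch.trunc(x), x)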
//===----------------------------------------------------------------------===//
// AtenSignOp
//===----------------------------------------------------------------------===//


@@ -6502,6 +6502,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.trunc\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.log\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -10003,6 +10007,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.trunc\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %int11 = torch.constant.int 11\n"


@@ -5886,6 +5886,32 @@ class DecomposeAtenCosineSimilarityOp
};
} // namespace
namespace {
// Decompose `aten.trunc` into `sign(x) * floor(abs(x))`.
class DecomposeAtenTruncOp : public OpRewritePattern<AtenTruncOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTruncOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.getSelf();
    auto resultTy = dyn_cast<ValueTensorType>(op.getType());
    if (!resultTy || !resultTy.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result must have dtype");
    }
    if (isa<mlir::FloatType>(resultTy.getDtype())) {
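      // Only floating-point results need the rewrite. The identity holds at
      // the edges: nan propagates through the final multiply, and
      // sign(+-inf) = +-1 with floor(abs(+-inf)) = inf preserves infinities.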
      Value sign = rewriter.create<AtenSgnOp>(loc, resultTy, self);
      Value abs = rewriter.create<AtenAbsOp>(loc, resultTy, self);
      Value floor = rewriter.create<AtenFloorOp>(loc, resultTy, abs);
      rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, resultTy, sign, floor);
      return success();
    }
    return failure();
  }
};
} // namespace
namespace {
// Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
// `aten.add.Tensor` op.
@@ -7700,6 +7726,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinTensorOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenTruncOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideScalarOp>(patterns);


@@ -512,6 +512,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenNormalFunctionalOp>();
  target.addIllegalOp<AtenVarMeanOp>();
  target.addIllegalOp<AtenCosineSimilarityOp>();
  target.addIllegalOp<AtenTruncOp>();
  target.addIllegalOp<AtenNewEmptyStridedOp>();
  target.addIllegalOp<AtenEmptyStridedOp>();
  target.addIllegalOp<AtenBucketizeTensorOp>();


@@ -1479,6 +1479,8 @@ STABLEHLO_PASS_SET = {
    "ElementwiseCoshModule_basic",
    "ElementwiseSinhIntModule_basic",
    "ElementwiseSinhModule_basic",
    "ElementwiseTruncIntModule_basic",
    "ElementwiseTruncModule_basic",
}

STABLEHLO_CRASHING_SET = {
@@ -1488,6 +1490,8 @@ STABLEHLO_CRASHING_SET = {

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
    "ElementwiseTruncModule_basic",
    "ElementwiseTruncIntModule_basic",
    "ElementwiseSgnModule_basic",
    "ElementwiseSignIntModule_basic",
    "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
@@ -2344,6 +2348,8 @@ ONNX_XFAIL_SET = {
    "ElementwiseSinhModule_basic",
    "ElementwiseCoshIntModule_basic",
    "ElementwiseCoshModule_basic",
    "ElementwiseTruncIntModule_basic",
    "ElementwiseTruncModule_basic",
    "ElementwiseDequantizePerChannelModule_basic",
    "ElementwiseDequantizePerTensorModule_basic",
    "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",


@@ -245,6 +245,9 @@ def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], mi
def aten〇ceil〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇trunc〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇log〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)
@@ -2227,6 +2230,11 @@ def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇trunc〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0))
def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float, complex]) -> int:
    self_rank, self_dtype = self_rank_dtype
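
The new trunc interpretation functions above simply record that the op is shape- and dtype-preserving; an eager-mode rendering of the same contract (illustrative only):

import torch

x = torch.rand(2, 3, dtype=torch.float64)
y = torch.trunc(x)
# Unary shape function: same shape; dtype function: dtype passes through.
assert y.shape == x.shape and y.dtype == x.dtype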


@@ -359,6 +359,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
    emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
    emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
    emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
    emit_with_mutating_variants("aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True)
    emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)


@@ -2077,6 +2077,50 @@ def ElementwiseCeilModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseTruncModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([1, 6], torch.float32, True),
    ])
    def forward(self, a):
        return torch.trunc(a)


@register_test_case(module_factory=lambda: ElementwiseTruncModule())
def ElementwiseTruncModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5]]))


# ==============================================================================


class ElementwiseTruncIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.trunc(a)


@register_test_case(module_factory=lambda: ElementwiseTruncIntModule())
def ElementwiseTruncIntModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int32))


# ==============================================================================


class ElementwiseSignModule(torch.nn.Module):
    def __init__(self):


@@ -2308,6 +2308,14 @@ func.func @torch.aten.floor$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> !
  return %0 : !torch.vtensor<[?,?],si64>
}
// CHECK-LABEL: func.func @torch.aten.trunc$canonicalize
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],si64>
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?,?],si64>
func.func @torch.aten.trunc$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
  %0 = torch.aten.trunc %arg0 : !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
  return %0 : !torch.vtensor<[?,?],si64>
}
// CHECK-LABEL: func.func @torch.aten.numel$canonicalize
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4],f32>
// CHECK-NEXT: %int12 = torch.constant.int 12