mirror of https://github.com/llvm/torch-mlir
parent
e18bf42d0e
commit
fab2696489
|
@ -4223,6 +4223,52 @@ def Torch_AtenRound_Op : Torch_Op<"aten.round_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenTruncOp : Torch_Op<"aten.trunc", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::trunc : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenTruncOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenTruncOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::trunc_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
Torch_NonValueTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalNonValueTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenTrunc_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenTrunc_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSignOp : Torch_Op<"aten.sign", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -1834,6 +1834,19 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
|
|||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenTruncOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultType = getType().dyn_cast<ValueTensorType>();
|
||||
if (resultType && resultType.hasDtype() &&
|
||||
resultType.getDtype().isa<mlir::IntegerType>()) {
|
||||
return getSelf();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSignOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -6502,6 +6502,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.trunc\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.log\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -10003,6 +10007,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.trunc\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
|
|
|
@ -5886,6 +5886,32 @@ class DecomposeAtenCosineSimilarityOp
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// decompose `trunc(x)` to `sign(x) * floor(abs(x))`
|
||||
class DecomposeAtenTruncOp : public OpRewritePattern<AtenTruncOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenTruncOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
|
||||
auto resultTy = dyn_cast<ValueTensorType>(op.getType());
|
||||
if (!resultTy || !resultTy.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(op, "result must have dtype");
|
||||
}
|
||||
|
||||
if (isa<mlir::FloatType>(resultTy.getDtype())) {
|
||||
Value sign = rewriter.create<AtenSgnOp>(loc, resultTy, self);
|
||||
Value abs = rewriter.create<AtenAbsOp>(loc, resultTy, self);
|
||||
Value floor = rewriter.create<AtenFloorOp>(loc, resultTy, abs);
|
||||
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, resultTy, sign, floor);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
|
||||
// `aten.add.Tensor` op.
|
||||
|
@ -7700,6 +7726,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTruncOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideScalarOp>(patterns);
|
||||
|
|
|
@ -512,6 +512,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenNormalFunctionalOp>();
|
||||
target.addIllegalOp<AtenVarMeanOp>();
|
||||
target.addIllegalOp<AtenCosineSimilarityOp>();
|
||||
target.addIllegalOp<AtenTruncOp>();
|
||||
target.addIllegalOp<AtenNewEmptyStridedOp>();
|
||||
target.addIllegalOp<AtenEmptyStridedOp>();
|
||||
target.addIllegalOp<AtenBucketizeTensorOp>();
|
||||
|
|
|
@ -1479,6 +1479,8 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseCoshModule_basic",
|
||||
"ElementwiseSinhIntModule_basic",
|
||||
"ElementwiseSinhModule_basic",
|
||||
"ElementwiseTruncIntModule_basic",
|
||||
"ElementwiseTruncModule_basic",
|
||||
}
|
||||
|
||||
STABLEHLO_CRASHING_SET = {
|
||||
|
@ -1488,6 +1490,8 @@ STABLEHLO_CRASHING_SET = {
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"ElementwiseTruncModule_basic",
|
||||
"ElementwiseTruncIntModule_basic",
|
||||
"ElementwiseSgnModule_basic",
|
||||
"ElementwiseSignIntModule_basic",
|
||||
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||
|
@ -2344,6 +2348,8 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseSinhModule_basic",
|
||||
"ElementwiseCoshIntModule_basic",
|
||||
"ElementwiseCoshModule_basic",
|
||||
"ElementwiseTruncIntModule_basic",
|
||||
"ElementwiseTruncModule_basic",
|
||||
"ElementwiseDequantizePerChannelModule_basic",
|
||||
"ElementwiseDequantizePerTensorModule_basic",
|
||||
"ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
|
||||
|
|
|
@ -245,6 +245,9 @@ def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], mi
|
|||
def aten〇ceil〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇trunc〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇log〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -2227,6 +2230,11 @@ def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇trunc〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0))
|
||||
def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -359,6 +359,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit_with_mutating_variants("aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||
emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||
|
||||
|
|
|
@ -2077,6 +2077,50 @@ def ElementwiseCeilModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseTruncModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 6], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.trunc(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseTruncModule())
|
||||
def ElementwiseTruncModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([[-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5]]))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseTruncIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.trunc(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseTruncIntModule())
|
||||
def ElementwiseTruncIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSignModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
@ -2308,6 +2308,14 @@ func.func @torch.aten.floor$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> !
|
|||
return %0 : !torch.vtensor<[?,?],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.trunc$canonicalize
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],si64>
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?,?],si64>
|
||||
func.func @torch.aten.trunc$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
|
||||
%0 = torch.aten.trunc %arg0 : !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
|
||||
return %0 : !torch.vtensor<[?,?],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.numel$canonicalize
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4],f32>
|
||||
// CHECK-NEXT: %int12 = torch.constant.int 12
|
||||
|
|
Loading…
Reference in New Issue