mirror of https://github.com/llvm/torch-mlir
Implement lowering for torch.aten.hann_window.periodic (#3502)
parent
b59efc75f3
commit
fde286f491
|
@ -12570,6 +12570,34 @@ def Torch_AtenBaddbmm_Op : Torch_Op<"aten.baddbmm_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenHannWindowPeriodicOp : Torch_Op<"aten.hann_window.periodic", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::hann_window.periodic : (int, bool, int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
Torch_IntType:$window_length,
|
||||
Torch_BoolType:$periodic,
|
||||
AnyTorchOptionalIntType:$dtype,
|
||||
AnyTorchOptionalIntType:$layout,
|
||||
AnyTorchOptionalDeviceType:$device,
|
||||
AnyTorchOptionalBoolType:$pin_memory
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenHannWindowPeriodicOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenHannWindowPeriodicOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -6619,6 +6619,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.hann_window.periodic\"(%arg0: !torch.int, %arg1: !torch.bool, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.hardshrink\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -10786,6 +10790,26 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.hann_window.periodic\"(%arg0: !torch.int, %arg1: !torch.bool, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.hardshrink\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
|
|
|
@ -8128,6 +8128,71 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.hann_window` into `aten.arange.start`, `aten.mul.Scalar`,
|
||||
// `aten.sin` and `aten.square` or into `aten.ones` in the trivial case
|
||||
class DecomposeAtenHannWindowPeriodicOp
|
||||
: public OpRewritePattern<AtenHannWindowPeriodicOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenHannWindowPeriodicOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op.getContext();
|
||||
Type opType = op.getType();
|
||||
|
||||
Value opWindowLength = op.getWindowLength();
|
||||
Value opDtype = op.getDtype();
|
||||
Value opLayout = op.getLayout();
|
||||
Value opDevice = op.getDevice();
|
||||
Value opPinMemory = op.getPinMemory();
|
||||
|
||||
int64_t window_length;
|
||||
if (!matchPattern(opWindowLength, m_TorchConstantInt(&window_length)) ||
|
||||
window_length <= 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected a constant integer greater than zero");
|
||||
bool periodic;
|
||||
if (!matchPattern(op.getPeriodic(), m_TorchConstantBool(&periodic)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected a constant boolean value for periodic");
|
||||
|
||||
if (window_length == 1) {
|
||||
Value one =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
SmallVector<Value> sizes({one});
|
||||
Value sizeList = rewriter.create<PrimListConstructOp>(
|
||||
loc, ListType::get(IntType::get(context)), sizes);
|
||||
rewriter.replaceOpWithNewOp<AtenOnesOp>(op, opType, sizeList, opDtype,
|
||||
opLayout, opDevice, opPinMemory);
|
||||
return success();
|
||||
}
|
||||
|
||||
Value zero =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
||||
|
||||
Value arange = rewriter.create<AtenArangeStartOp>(
|
||||
loc, opType, zero, op.getWindowLength(), opDtype, opLayout, opDevice,
|
||||
opPinMemory);
|
||||
|
||||
double denominator = !periodic ? window_length - 1 : window_length;
|
||||
|
||||
double piOverDenominator = 3.14159 / denominator;
|
||||
|
||||
Value cstFactor = rewriter.create<ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(piOverDenominator));
|
||||
|
||||
Value fraction =
|
||||
rewriter.create<AtenMulScalarOp>(loc, opType, arange, cstFactor);
|
||||
Value sine = rewriter.create<AtenSinOp>(loc, opType, fraction);
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenSquareOp>(op, opType, sine);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.scatter.value` op into `aten.scatter.src` op.
|
||||
class DecomposeAtenScatterValueOp
|
||||
|
@ -8989,6 +9054,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSgnOp>(patterns);
|
||||
|
|
|
@ -540,6 +540,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenCrossEntropyLossOp>();
|
||||
target.addIllegalOp<AtenVarMeanDimOp>();
|
||||
target.addIllegalOp<AtenTopkOp>();
|
||||
target.addIllegalOp<AtenHannWindowPeriodicOp>();
|
||||
target.addIllegalOp<AtenScalarTensorOp>();
|
||||
target.addIllegalOp<AtenScatterValueOp>();
|
||||
target.addIllegalOp<AtenTypeAsOp>();
|
||||
|
|
|
@ -880,6 +880,8 @@ STABLEHLO_PASS_SET = {
|
|||
"ArgmaxModule_with_dim",
|
||||
"AtenComplex64Module_basic",
|
||||
"AtenFloatScalarModule_basic",
|
||||
"AtenHannWindowPeriodicFalseModule_basic",
|
||||
"AtenHannWindowPeriodicTrueModule_basic",
|
||||
"AtenIntBoolOpConstFalseModule_basic",
|
||||
"AtenIntBoolOpConstTrueModule_basic",
|
||||
"AtenIntBoolOpModule_basic",
|
||||
|
@ -2839,6 +2841,8 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"AtenEyeMModuleInt2D_basic",
|
||||
"AtenEyeModuleInt2D_basic",
|
||||
"AtenFloatScalarModule_basic",
|
||||
"AtenHannWindowPeriodicTrueModule_basic",
|
||||
"AtenHannWindowPeriodicFalseModule_basic",
|
||||
"AtenInstanceNormModule_basic",
|
||||
"AtenIntBoolOpConstFalseModule_basic",
|
||||
"AtenIntBoolOpConstTrueModule_basic",
|
||||
|
|
|
@ -290,6 +290,9 @@ def aten〇log〡shape(self: List[int]) -> List[int]:
|
|||
def aten〇log_sigmoid〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇hann_window〇periodic〡shape(window_length: int, periodic: bool, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
return [window_length]
|
||||
|
||||
def aten〇hardshrink〡shape(self: List[int], lambd: float = 0.5) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -2444,6 +2447,15 @@ def aten〇log_sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
|||
assert not self_dtype == torch.bool
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function([Invocation(10, False), Invocation(10, True),
|
||||
Invocation(10, False, dtype=torch.float32), Invocation(10, True, dtype=torch.float32),
|
||||
Invocation(10, False, dtype=torch.float64), Invocation(10, True, dtype=torch.float64)])
|
||||
def aten〇hann_window〇periodic〡dtype(window_length: int, periodic: bool, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
result_dtype = torch.float32 if dtype is None else dtype
|
||||
assert is_float_dtype(result_dtype)
|
||||
return result_dtype
|
||||
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, lambd=0.5))
|
||||
def aten〇hardshrink〡dtype(self_rank_dtype: Tuple[int, int], lambd: Union[int, float, complex] = 0.5) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -920,6 +920,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit_with_mutating_variants(
|
||||
"aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::hann_window.periodic : (int, bool, int?, int?, Device?, bool?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
|
||||
emit("aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)")
|
||||
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
|
|
|
@ -41,6 +41,7 @@ def register_all_tests():
|
|||
from . import elementwise_comparison
|
||||
from . import squeeze
|
||||
from . import slice_like
|
||||
from . import spectral
|
||||
from . import nll_loss
|
||||
from . import index_select
|
||||
from . import linalg_algorithms
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.framework import TestUtils
|
||||
from torch_mlir_e2e_test.registry import register_test_case
|
||||
from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenHannWindowPeriodicFalseModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.ops.aten.hann_window(20, False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenHannWindowPeriodicFalseModule())
|
||||
def AtenHannWindowPeriodicFalseModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class AtenHannWindowPeriodicTrueModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.ops.aten.hann_window(20, True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenHannWindowPeriodicTrueModule())
|
||||
def AtenHannWindowPeriodicTrueModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
Loading…
Reference in New Issue