mirror of https://github.com/llvm/torch-mlir
[torch] Add support for aten.selu (#2640)
Add `aten.selu` operation to `torch` dialect.
parent 42392bc845
commit 6ddeb1a6ef
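For reference, the decomposition added in this patch computes SELU as scale * (max(0, x) + min(0, alpha * (exp(x) - 1))), with scale and alpha fixed to the constants PyTorch uses. A minimal eager-PyTorch sketch of that decomposition (not part of the diff, just a sanity check against the built-in):

```python
# Sketch (not part of the commit): the decomposition emitted by
# DecomposeAtenSeluOp, written in eager PyTorch and checked against
# torch.nn.functional.selu.
import torch

def selu_decomposed(x: torch.Tensor) -> torch.Tensor:
    scale = 1.0507009873554804934193349852946
    alpha = 1.6732632423543772848170429916717
    zero = torch.zeros_like(x)
    positive = torch.maximum(zero, x)
    negative = alpha * (torch.exp(torch.minimum(zero, x)) - 1)
    return scale * (positive + negative)

x = torch.randn(5, 3)
torch.testing.assert_close(selu_decomposed(x), torch.nn.functional.selu(x))
```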
@@ -346,6 +346,51 @@ def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
  }];
}

def Torch_AtenSeluOp : Torch_Op<"aten.selu", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::selu : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenSeluOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenSeluOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenSelu_Op : Torch_Op<"aten.selu_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::selu_ : (Tensor) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self
  );
  let results = (outs
    Torch_NonValueTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenSelu_Op::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenSelu_Op::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenSigmoidOp : Torch_Op<"aten.sigmoid", [
    AllowsTypeRefinement,
    HasValueSemantics,
@@ -6746,6 +6746,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.gather\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg2) : (!torch.list<int>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
@@ -10434,6 +10438,28 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    }\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.selu\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
"    %none = torch.constant.none\n"
"    %str = torch.constant.str \"AssertionError: \"\n"
"    %int11 = torch.constant.int 11\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
"    %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
"    torch.prim.If %1 -> () {\n"
"      torch.prim.If.yield\n"
"    } else {\n"
"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
"      torch.prim.If.yield\n"
"    }\n"
"    %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
"    %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
"    torch.prim.If %3 -> () {\n"
"      torch.prim.If.yield\n"
"    } else {\n"
"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
"      torch.prim.If.yield\n"
"    }\n"
"    return %0#1 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
"    %none = torch.constant.none\n"
"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@@ -1937,6 +1937,55 @@ public:
};
} // namespace

// Selu = scale * (max(0,x) + min(0,alpha * (exp(x) - 1)))
namespace {
class DecomposeAtenSeluOp : public OpRewritePattern<AtenSeluOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSeluOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    auto resType = op.getType().cast<BaseTensorType>();
    if (!resType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }

    // Define λ and α
    double scale = 1.0507009873554804934193349852946;
    double alpha = 1.6732632423543772848170429916717;

    // Create constants for λ and α
    Value scaleVal = rewriter.create<Torch::ConstantFloatOp>(loc, rewriter.getF64FloatAttr(scale));
    Value alphaVal = rewriter.create<Torch::ConstantFloatOp>(loc, rewriter.getF64FloatAttr(alpha));

    // Create zero tensor for comparison
    Value constantZero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);

    // Calculate positive and negative parts
    Value constantOne =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
    Value positiveOutput = rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
    Value minZeroX =
        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
    Value expInput = rewriter.create<AtenExpOp>(loc, resType, minZeroX);
    Value expInputMinusOne = rewriter.create<AtenSubScalarOp>(loc, resType, expInput, constantOne, constantOne);
    Value negativeOutput = rewriter.create<AtenMulScalarOp>(loc, resType, expInputMinusOne, alphaVal);

    // Multiply the result by λ
    Value seluOutput = rewriter.create<AtenAddTensorOp>(
        loc, resType, positiveOutput, negativeOutput, constantOne);
    seluOutput = rewriter.create<AtenMulScalarOp>(loc, resType, seluOutput, scaleVal);

    // Replace the original operation
    rewriter.replaceOp(op, seluOutput);
    return success();
  }
};
} // namespace

namespace {
class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
public:
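Note (reviewer commentary, not from the patch): the pattern feeds min(0, x) into exp, so the IR it emits computes alpha * (exp(min(0, x)) - 1) rather than the literal min(0, alpha * (exp(x) - 1)) from the comment. For alpha > 0 the two are equal, and exponentiating only the clamped value avoids evaluating exp on large positive inputs. A quick eager check of that identity:

```python
# Sketch: min(0, alpha*(exp(x) - 1)) == alpha*(exp(min(0, x)) - 1) for alpha > 0.
import torch

x = torch.linspace(-6, 6, steps=25)
alpha = 1.6732632423543772848170429916717
zero = torch.zeros_like(x)
lhs = torch.minimum(zero, alpha * (torch.exp(x) - 1))
rhs = alpha * (torch.exp(torch.minimum(zero, x)) - 1)
torch.testing.assert_close(lhs, rhs)
```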
@@ -6460,6 +6509,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
@@ -437,6 +437,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenRelu6Op>();
  target.addIllegalOp<AtenEluOp>();
  target.addIllegalOp<AtenGluOp>();
  target.addIllegalOp<AtenSeluOp>();
  target.addIllegalOp<AtenHardswishOp>();
  target.addIllegalOp<AtenSoftplusOp>();
  target.addIllegalOp<AtenSiluOp>();
@@ -486,6 +486,7 @@ STABLEHLO_PASS_SET = {
    "ElementwiseLeakyReluModule_basic",
    "ElementwiseEluModule_basic",
    "ElementwiseEluNonDefaultModule_basic",
    "ElementwiseSeluModule_basic",
    "ElementwiseLogModule_basic",
    "ElementwiseNegModule_basic",
    "ElementwiseRsqrtModule_basic",
@@ -1115,6 +1116,7 @@ TOSA_PASS_SET = {
    "ElementwiseRemainderScalarModule_Int_basic",
    "ElementwiseRemainderScalarModule_Int_basic",
    "ElementwiseRsqrtModule_basic",
    "ElementwiseSeluModule_basic",
    "ElementwiseSigmoidModule_basic",
    "ElementwiseSignModule_basic",
    "ElementwiseSqrtIntModule_basic",
@@ -373,6 +373,9 @@ def aten〇elu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu
def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇selu〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇gather〡shape(self: List[int], dim: int, index: List[int], sparse_grad: bool = False) -> List[int]:
    return upstream_shape_functions.unary(index)

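Note (not from the patch): the shape transfer function for aten.selu is the usual unary rule, i.e. the output shape equals the input shape. A trivial stand-in for upstream_shape_functions.unary, just to illustrate the contract:

```python
# Sketch: a stand-in for torch.jit._shape_functions.unary, which returns a
# shape list equal to the input shape list.
from typing import List

def unary(self: List[int]) -> List[int]:
    return list(self)

assert unary([5, 3]) == [5, 3]
```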
@@ -3066,6 +3069,14 @@ def aten〇elu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float
    assert not is_integer_dtype(self_dtype)
    return self_dtype

@check_dtype_function(
    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}))
def aten〇selu〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    assert self_dtype != torch.bool
    assert not is_integer_dtype(self_dtype)
    return self_dtype

@check_dtype_function(
    _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
    _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
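Aside (not part of the diff): the %int11 constant in the generated library above is the ScalarType code for torch.bool, and the error_types set in the decorator lists the integer dtypes, so the rule matches eager PyTorch, where selu is only implemented for floating-point inputs. A small sanity check, assuming only stock PyTorch:

```python
# Sketch: eager PyTorch accepts floating-point inputs for selu and rejects
# integer ones, matching the assertions in aten〇selu〡dtype above.
import torch

torch.nn.functional.selu(torch.randn(2, 2))  # float32: fine
try:
    torch.nn.functional.selu(torch.ones(2, 2, dtype=torch.int64))
except RuntimeError as err:  # expected: not implemented for integer dtypes
    print("integer input rejected:", err)
```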
@@ -262,6 +262,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
        "aten::relu6 : (Tensor) -> (Tensor)",
        "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
        "aten::log : (Tensor) -> (Tensor)",
        "aten::selu : (Tensor) -> (Tensor)",
        "aten::sigmoid : (Tensor) -> (Tensor)",
        "aten::sign : (Tensor) -> (Tensor)",
        "aten::sgn : (Tensor) -> (Tensor)",
@@ -564,6 +564,27 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseSeluModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.selu(x)

@register_test_case(module_factory=lambda: ElementwiseSeluModule())
def ElementwiseSeluModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 3, low=-1, high=1))


# ==============================================================================


class ElementwiseSigmoidModule(torch.nn.Module):

    def __init__(self):