[Torch] emit aten.celu and decompose it (#3247)

CELU(x)=max(0,x)+min(0,α∗(exp(x/α)−1))
pull/3256/head
Xinyu Yang 2024-04-28 17:23:40 +08:00 committed by GitHub
parent 46c0f3cad0
commit 5684dc0441
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 159 additions and 0 deletions

View File

@ -4810,6 +4810,53 @@ def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [
}];
}
def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$alpha
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenCeluOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
AnyTorchScalarType:$alpha
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenCelu_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenRealOp : Torch_Op<"aten.real", [
AllowsTypeRefinement,
ReadOnly

View File

@ -6998,6 +6998,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.celu\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -10480,6 +10484,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.celu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"

View File

@ -2415,6 +2415,50 @@ public:
} // namespace
// CELU(x)=max(0,x)+min(0,alpha(exp(x/alpha)1))
namespace {
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCeluOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();
Value alpha = op.getAlpha();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value constantZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value constantOne =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
// positiveOutput = max(0,x)
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
Value positiveOutput =
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
// negativeOutput = min(0,alpha(exp(x/alpha)1))
Value scaledInput =
rewriter.create<AtenDivScalarOp>(loc, resType, input, alpha);
Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledInput);
Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
constantOne, constantOne);
Value scaledExpXM1 =
rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, alpha);
Value negativeOutput =
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledExpXM1);
Value celuOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOne);
rewriter.replaceOp(op, celuOutput);
return success();
}
};
} // namespace
namespace {
class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> {
public:
@ -7705,6 +7749,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);

View File

@ -474,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
target.addIllegalOp<AtenPadOp>();
target.addIllegalOp<AtenPreluOp>();
target.addIllegalOp<AtenCeluOp>();
target.addIllegalOp<AtenToDtypeLayoutOp>();
target.addIllegalOp<AtenToDeviceOp>();
target.addIllegalOp<AtenToPrimDeviceOp>();

View File

@ -951,6 +951,7 @@ STABLEHLO_PASS_SET = {
"ElementwiseBitwiseRightShiftInt64Module_basic",
"ElementwiseBitwiseRightShiftInt8Module_basic",
"ElementwiseCeilModule_basic",
"ElementwiseCeluStaticModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampMinTensorFloatModule_basic",
@ -1571,6 +1572,8 @@ TOSA_PASS_SET = {
"ElementwiseBitwiseXorModule_basic",
"ElementwiseBitwiseXorStaticShapeModule_basic",
"ElementwiseCeilModule_basic",
"ElementwiseCeluModule_basic",
"ElementwiseCeluStaticModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampModule_basic",

View File

@ -526,6 +526,9 @@ def atenelu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu
def atenprelu〡shape(self: List[int], weight: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atencelu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
return upstream_shape_functions.unary(self)
def atenselu〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
@ -2652,6 +2655,11 @@ def atenprelu〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tu
assert self_dtype == weight_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, alpha=1.))
def atencelu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1.) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
def atenrelu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -472,6 +472,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
emit_with_mutating_variants("aten::celu : (Tensor, Scalar) -> (Tensor)")
emit("aten::real : (Tensor) -> (Tensor)")
emit("aten::imag : (Tensor) -> (Tensor)")
emit("aten::view_as_complex : (Tensor) -> (Tensor)")

View File

@ -1016,6 +1016,52 @@ def ElementwisePreluStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseCeluModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.celu(x, 0.5)
@register_test_case(module_factory=lambda: ElementwiseCeluModule())
def ElementwiseCeluModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1))
# ==============================================================================
class ElementwiseCeluStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([5, 3], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.celu(x)
@register_test_case(module_factory=lambda: ElementwiseCeluStaticModule())
def ElementwiseCeluStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1))
# ==============================================================================
class ElementwiseGeluModule(torch.nn.Module):
def __init__(self):
super().__init__()