Add leakyrelu support

pull/436/head snapshot-20211127.110
Liam Fitzpatrick 2021-11-24 13:38:59 +00:00 committed by Prashant Kumar
parent f461a7ebce
commit 7616d28ce1
5 changed files with 73 additions and 4 deletions


@@ -196,6 +196,24 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 2) - 0.5)
# ==============================================================================
class ElementwiseLeakyReluModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.leaky_relu(x, negative_slope=0.1)


@register_test_case(module_factory=lambda: ElementwiseLeakyReluModule())
def ElementwiseLeakyReluModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 2) - 0.5)
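
A minimal standalone check of the numerics the new test exercises, assuming only a plain PyTorch install (this sketch does not use the torch-mlir e2e harness):

import torch

# Same input distribution the test uses.
x = torch.rand(4, 2) - 0.5
# The op the module's forward() calls, with the same hard-coded slope.
out = torch.ops.aten.leaky_relu(x, 0.1)
# PyTorch's eager reference implementation.
ref = torch.nn.functional.leaky_relu(x, negative_slope=0.1)
assert torch.allclose(out, ref)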
# ==============================================================================
class ElementwiseGeluModule(torch.nn.Module):


@@ -72,6 +72,36 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [
  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}
def Torch_AtenLeakyReluOp : Torch_Op<"aten.leaky_relu", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::leaky_relu : (Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$negative_slope
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $negative_slope attr-dict `:` type($self) `,` type($negative_slope) `->` type($result)";
}

def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::leaky_relu_ : (Tensor, Scalar) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$negative_slope
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $negative_slope attr-dict `:` type($self) `,` type($negative_slope) `->` type($result)";
}
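
The two definitions above model the functional and in-place forms of the op. A quick sketch of the distinction in plain PyTorch (assumes only a stock install, independent of torch-mlir):

import torch

x = torch.tensor([-1.0, 2.0])
# aten::leaky_relu has value semantics: returns a new tensor, x is untouched.
y = torch.ops.aten.leaky_relu(x, 0.1)
# aten::leaky_relu_ is the trailing-underscore in-place variant: it mutates x.
torch.ops.aten.leaky_relu_(x, 0.1)
assert torch.equal(x, y)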
def Torch_AtenLogOp : Torch_Op<"aten.log", [
    AllowsTypeRefinement,
    HasValueSemantics


@@ -10,6 +10,7 @@
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -1378,11 +1379,30 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    }
    Type elementType = payloadArgs[0].getType();
    Value constZero =
-        b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
+        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                         payloadArgs[0], constZero);
    return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
  }
  if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) {
    if (!lrelu.getType()
             .cast<ValueTensorType>()
             .getDtype()
             .isa<mlir::FloatType>()) {
      lrelu.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    Type elementType = payloadArgs[0].getType();
    Value constZero =
        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                         payloadArgs[0], constZero);
    Value positivePart =
        b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
    Value negativePart =
        b.create<SelectOp>(loc, pred, constZero, payloadArgs[0]);
    Value scale = convertScalarToDtype(b, loc, operands[1], elementType);
    Value scaledNegativePart =
        b.create<arith::MulFOp>(loc, negativePart, scale);
    return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
  }
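
Read as per-element scalar math, the payload above keeps the positive part of the input and adds the scaled negative part. A small Python sketch of that formula (illustrative only, not generated code):

def leaky_relu_scalar(x, negative_slope):
    # Mirrors select(pred, x, 0) and select(pred, 0, x) from the payload.
    positive_part = x if x > 0.0 else 0.0
    negative_part = 0.0 if x > 0.0 else x
    return positive_part + negative_slope * negative_part

assert leaky_relu_scalar(3.0, 0.5) == 3.0
assert leaky_relu_scalar(-4.0, 0.5) == -2.0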
  if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
    if (!gelu.getType()
             .cast<ValueTensorType>()
@@ -1812,7 +1832,7 @@ struct ConvertElementwiseOp : ConversionPattern {
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
+    if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
             AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
             AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
             AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
@@ -2969,7 +2989,7 @@ public:
    target.addIllegalOp<AtenBatchNormOp>();
    patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
    target.addIllegalOp<
-        AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
+        AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
        AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
        AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
        AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,


@@ -289,7 +289,7 @@ public:
      return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
    } else if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp,
                   AtenDivScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp,
-                  AtenPowTensorScalarOp, AtenRsubScalarOp>(op)) {
+                  AtenPowTensorScalarOp, AtenRsubScalarOp, AtenLeakyReluOp>(op)) {
      return visitBinaryTensorScalarOp(op, operands);
    } else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
                   AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
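
Adding AtenLeakyReluOp to this list reuses the existing (Tensor, Scalar) visitor for shape and dtype refinement. A quick check of the corresponding eager behavior, assuming a plain PyTorch install:

import torch

x = torch.rand(4, 2, dtype=torch.float32)
y = torch.ops.aten.leaky_relu(x, 0.1)
# The result keeps the input tensor's shape and floating-point dtype.
assert y.shape == x.shape and y.dtype == torch.float32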


@@ -439,6 +439,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    for key in [
        "aten::tanh : (Tensor) -> (Tensor)",
        "aten::relu : (Tensor) -> (Tensor)",
        "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
        "aten::log : (Tensor) -> (Tensor)",
        "aten::sigmoid : (Tensor) -> (Tensor)",
        "aten::sin : (Tensor) -> (Tensor)",