mirror of https://github.com/llvm/torch-mlir
parent
f461a7ebce
commit
7616d28ce1
|
@ -196,6 +196,24 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(4, 2) - 0.5)
|
||||
|
||||
# ==============================================================================
|
||||
class ElementwiseLeakyReluModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.leaky_relu(x, negative_slope=0.1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseLeakyReluModule())
|
||||
def ElementwiseLeakyReluModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 2) - 0.5)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeluModule(torch.nn.Module):
|
||||
|
|
|
@ -72,6 +72,36 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [
|
|||
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenLeakyReluOp : Torch_Op<"aten.leaky_relu", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::leaky_relu : (Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$negative_slope
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $negative_slope attr-dict `:` type($self) `,` type($negative_slope) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::leaky_relu_ : (Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$negative_slope
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $negative_slope attr-dict `:` type($self) `,` type($negative_slope) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenLogOp : Torch_Op<"aten.log", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
@ -1378,11 +1379,30 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
Type elementType = payloadArgs[0].getType();
|
||||
Value constZero =
|
||||
b.create<arith::ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
|
||||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||
payloadArgs[0], constZero);
|
||||
return b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
|
||||
}
|
||||
if (auto lrelu = dyn_cast<AtenLeakyReluOp>(op)) {
|
||||
if (!lrelu.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
lrelu.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Type elementType = payloadArgs[0].getType();
|
||||
Value constZero =
|
||||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
|
||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||
payloadArgs[0], constZero);
|
||||
Value positivePart = b.create<SelectOp>(loc, pred, payloadArgs[0], constZero);
|
||||
Value negativePart = b.create<SelectOp>(loc, pred, constZero, payloadArgs[0]);
|
||||
Value scale = convertScalarToDtype(b, loc, operands[1], elementType);
|
||||
Value scaledNegativePart = b.create<arith::MulFOp>(loc, negativePart, scale);
|
||||
return b.create<arith::AddFOp>(loc, positivePart, scaledNegativePart);
|
||||
}
|
||||
if (auto gelu = dyn_cast<AtenGeluOp>(op)) {
|
||||
if (!gelu.getType()
|
||||
.cast<ValueTensorType>()
|
||||
|
@ -1812,7 +1832,7 @@ struct ConvertElementwiseOp : ConversionPattern {
|
|||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!isa<AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
||||
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
||||
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
|
||||
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
|
||||
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
|
||||
|
@ -2969,7 +2989,7 @@ public:
|
|||
target.addIllegalOp<AtenBatchNormOp>();
|
||||
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
|
||||
target.addIllegalOp<
|
||||
AtenTanhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
|
||||
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
|
||||
AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
|
||||
AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||
AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
|
||||
|
|
|
@ -289,7 +289,7 @@ public:
|
|||
return visitAtenAdaptiveAvgPool2dOp(avgPool2d, operands);
|
||||
} else if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp,
|
||||
AtenDivScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp,
|
||||
AtenPowTensorScalarOp, AtenRsubScalarOp>(op)) {
|
||||
AtenPowTensorScalarOp, AtenRsubScalarOp, AtenLeakyReluOp>(op)) {
|
||||
return visitBinaryTensorScalarOp(op, operands);
|
||||
} else if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp,
|
||||
AtenDivTensorOp, Aten__And__TensorOp, AtenEqTensorOp,
|
||||
|
|
|
@ -439,6 +439,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
for key in [
|
||||
"aten::tanh : (Tensor) -> (Tensor)",
|
||||
"aten::relu : (Tensor) -> (Tensor)",
|
||||
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::log : (Tensor) -> (Tensor)",
|
||||
"aten::sigmoid : (Tensor) -> (Tensor)",
|
||||
"aten::sin : (Tensor) -> (Tensor)",
|
||||
|
|
Loading…
Reference in New Issue