mirror of https://github.com/llvm/torch-mlir
Add sigmoid lowering
Follows existing conventions for activation functionspull/296/head
parent
29e1b2fe89
commit
d9df4bfc95
|
@ -149,7 +149,6 @@ class ElementwiseFlattenBroadcastModule(torch.nn.Module):
|
|||
def ElementwiseFlattenBroadcastModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6), tu.rand())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
@ -169,3 +168,24 @@ class ElementwiseReluModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseReluModule())
|
||||
def ElementwiseReluModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 2) - 0.5)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSigmoidModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.sigmoid(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSigmoidModule())
|
||||
def ElementwiseSigmoidModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
|
|
@ -436,6 +436,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
for key in [
|
||||
"aten::tanh : (Tensor) -> (Tensor)",
|
||||
"aten::relu : (Tensor) -> (Tensor)",
|
||||
"aten::sigmoid : (Tensor) -> (Tensor)",
|
||||
"aten::sin : (Tensor) -> (Tensor)",
|
||||
"aten::exp : (Tensor) -> (Tensor)",
|
||||
"aten::cos : (Tensor) -> (Tensor)",
|
||||
|
|
|
@ -71,6 +71,34 @@ def Torch_AtenRelu_Op : Torch_Op<"aten.relu_", [
|
|||
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenSigmoidOp : Torch_Op<"aten.sigmoid", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::sigmoid : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenSigmoid_Op : Torch_Op<"aten.sigmoid_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::sigmoid_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenSinOp : Torch_Op<"aten.sin", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -654,6 +654,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
ArrayRef<Value> operands) {
|
||||
if (isa<AtenTanhOp>(op))
|
||||
return b.create<math::TanhOp>(loc, payloadArgs[0]);
|
||||
if (isa<AtenSigmoidOp>(op)){
|
||||
Type elementType = payloadArgs[0].getType();
|
||||
auto one = b.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||
auto negate = b.create<NegFOp>(loc, payloadArgs[0]);
|
||||
auto exp = b.create<math::ExpOp>(loc, negate);
|
||||
auto added = b.create<AddFOp>(loc, exp, one);
|
||||
return b.create<DivFOp>(loc, one, added);
|
||||
}
|
||||
if (auto relu = dyn_cast<AtenReluOp>(op)) {
|
||||
if (!relu.getType()
|
||||
.cast<ValueTensorType>()
|
||||
|
@ -775,7 +783,8 @@ struct ConvertElementwiseOp : ConversionPattern {
|
|||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!isa<AtenTanhOp, AtenReluOp, AtenAddTensorOp, AtenMulTensorOp,
|
||||
AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp>(op))
|
||||
AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
|
||||
AtenSigmoidOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -1137,7 +1146,8 @@ public:
|
|||
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
|
||||
target
|
||||
.addIllegalOp<AtenTanhOp, AtenReluOp, AtenAddTensorOp, AtenMulTensorOp,
|
||||
AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp>();
|
||||
AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
|
||||
AtenSigmoidOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenUnsqueezeOp>();
|
||||
patterns.add<ConvertAtenUnsqueezeOp>(typeConverter, context);
|
||||
|
|
|
@ -175,7 +175,7 @@ public:
|
|||
AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp, AtenFmodScalarOp,
|
||||
AtenFloorDivideScalarOp, AtenEqScalarOp, AtenGeScalarOp,
|
||||
AtenNeScalarOp, AtenBitwiseNotOp, AtenToDtypeOp, AtenExpOp,
|
||||
AtenSinOp, AtenCosOp, DerefineOp>(op)) {
|
||||
AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp>(op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue