diff --git a/frontends/pytorch/e2e_testing/torchscript/basic.py b/frontends/pytorch/e2e_testing/torchscript/basic.py
index f0cb1205c..fc1a53650 100644
--- a/frontends/pytorch/e2e_testing/torchscript/basic.py
+++ b/frontends/pytorch/e2e_testing/torchscript/basic.py
@@ -67,3 +67,18 @@ class MmTanhModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: MmTanhModule())
 def MmTanhModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 2), tu.rand(2, 4))
+
+class ReluModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.relu(x)
+
+@register_test_case(module_factory=lambda: ReluModule())
+def ReluModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 2) - 0.5)
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 6f6f3d351..3c23efcfb 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -249,7 +249,19 @@ public:
 };
 } // namespace
 
+static Value createScalarRelu(OpBuilder &b, Location loc, ValueRange args) {
+  Type elementType = args[0].getType();
+  // TODO: Add support for integer types.
+  assert(elementType.isa<::mlir::FloatType>() &&
+         "Only support float case for relu");
+
+  Value constZero = b.create<ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
+  Value pred = b.create<CmpFOp>(loc, CmpFPredicate::UGT, args[0], constZero);
+  return b.create<SelectOp>(loc, pred, args[0], constZero);
+}
+
 namespace {
+
 // Converts a unary op. There is no implicit broadcasting behavior, so these can
 // be trivially lowered to linalg.
 // TODO: For binary ops, we will need a "linalg.generic-like" op that models
@@ -264,7 +276,7 @@ struct ConvertUnaryOp : ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenTanhOp>(op))
+    if (!isa<AtenTanhOp>(op) && !isa<AtenReluOp>(op))
       return rewriter.notifyMatchFailure(op, "not a unary op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -287,9 +299,11 @@ struct ConvertUnaryOp : ConversionPattern {
         /*iteratorTypes=*/iteratorTypes,
         [&](OpBuilder &b, Location loc, ValueRange args) {
           Value result;
-          if (isa<AtenTanhOp>(op)) {
+          if (isa<AtenTanhOp>(op))
             result = b.create<math::TanhOp>(loc, args[0]);
-          }
+          else if (isa<AtenReluOp>(op))
+            result = createScalarRelu(b, loc, args);
+
           b.create<linalg::YieldOp>(loc, result);
         });
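
Note on the predicate choice: createScalarRelu lowers relu(x) to select(x UGT 0, x, 0). CmpFPredicate::UGT is an unordered comparison, i.e. it is also true when the operand is NaN, so a NaN input selects the NaN itself rather than being clamped to zero. A minimal Python sketch of the resulting scalar semantics, for reference only (not part of the patch; NumPy stands in for the generated IR purely for illustration):

    import numpy as np

    def scalar_relu(x: np.ndarray) -> np.ndarray:
        # Mirrors createScalarRelu: pred = (x UGT 0.0); result = pred ? x : 0.0.
        # "Unordered greater than" is true for NaN, so NaN propagates.
        pred = np.isnan(x) | (x > 0.0)
        return np.where(pred, x, 0.0)

    assert (scalar_relu(np.array([-1.0, 0.0, 2.0])) == np.array([0.0, 0.0, 2.0])).all()
    assert np.isnan(scalar_relu(np.array([np.nan]))[0])

The e2e test feeds tu.rand(4, 2) - 0.5 so the inputs straddle zero and both branches of the select are exercised.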