Add aten.relu Linalg lowering support

pull/227/head
Yi Zhang 2021-06-15 16:00:37 +00:00 committed by Sean Silva
parent 3ccf6002af
commit 7b7c9c5d3d
2 changed files with 32 additions and 3 deletions

@@ -67,3 +67,18 @@ class MmTanhModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: MmTanhModule())
 def MmTanhModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 2), tu.rand(2, 4))
+
+class ReluModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.relu(x)
+
+@register_test_case(module_factory=lambda: ReluModule())
+def ReluModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 2) - 0.5)
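
Note on the test input: tu.rand(4, 2) samples from [0, 1), so the "- 0.5" shift is what puts negative values into the tensor; without it relu would be the identity on every element and the clamp path would go untested. A minimal sketch of the property being exercised (plain PyTorch, not the e2e harness API):

    import torch

    # Shifted input straddles zero, so both branches of relu fire.
    x = torch.rand(4, 2) - 0.5
    y = torch.relu(x)
    assert (y >= 0).all()                   # negatives are clamped to 0
    assert torch.equal(y[x > 0], x[x > 0])  # positives pass through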

@@ -249,7 +249,19 @@ public:
 };
 } // namespace
+
+static Value createScalarRelu(OpBuilder &b, Location loc, ValueRange args) {
+  Type elementType = args[0].getType();
+  // TODO: Add support for integer types.
+  assert(elementType.isa<::mlir::FloatType>() &&
+         "Only support float case for relu");
+  Value constZero =
+      b.create<ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
+  Value pred = b.create<CmpFOp>(loc, CmpFPredicate::UGT, args[0], constZero);
+  return b.create<SelectOp>(loc, pred, args[0], constZero);
+}
+
 namespace {
 // Converts a unary op. There is no implicit broadcasting behavior, so these can
 // be trivially lowered to linalg.
 // TODO: For binary ops, we will need a "linalg.generic-like" op that models
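
The payload above lowers relu to a compare-and-select: max(x, 0) is expressed as select(x > 0, x, 0). CmpFPredicate::UGT is an unordered predicate, so a NaN input compares true and is selected through rather than clamped to zero. A hedged sketch of the same scalar semantics in Python (illustrative only; scalar_relu is not part of the codebase):

    import math

    def scalar_relu(x: float) -> float:
        # Mirrors CmpFOp(UGT) + SelectOp: under an unordered predicate
        # NaN compares true, so NaN propagates instead of clamping to 0.
        pred = math.isnan(x) or x > 0.0
        return x if pred else 0.0

    assert scalar_relu(1.5) == 1.5
    assert scalar_relu(-2.0) == 0.0
    assert math.isnan(scalar_relu(float("nan")))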
@@ -264,7 +276,7 @@ struct ConvertUnaryOp : ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenTanhOp>(op))
+    if (!isa<AtenTanhOp>(op) && !isa<AtenReluOp>(op))
       return rewriter.notifyMatchFailure(op, "not a unary op");
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -287,9 +299,11 @@ struct ConvertUnaryOp : ConversionPattern {
         /*iteratorTypes=*/iteratorTypes,
         [&](OpBuilder &b, Location loc, ValueRange args) {
           Value result;
-          if (isa<AtenTanhOp>(op)) {
+          if (isa<AtenTanhOp>(op))
             result = b.create<math::TanhOp>(loc, args[0]);
-          }
+          else if (isa<AtenReluOp>(op))
+            result = createScalarRelu(b, loc, args);
           b.create<linalg::YieldOp>(loc, result);
         });
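
Taken together, ConvertUnaryOp now matches either op and picks the scalar payload inside the linalg.generic region per op kind. A rough Python analogue of that dispatch, purely illustrative (the real pattern builds MLIR ops rather than computing tensors):

    import torch

    def convert_unary(op_name: str, x: torch.Tensor) -> torch.Tensor:
        # Analogue of the payload selection in the linalg.generic body.
        if op_name == "aten.tanh":
            return torch.tanh(x)
        elif op_name == "aten.relu":
            return torch.where(x > 0, x, torch.zeros_like(x))
        raise ValueError("not a unary op")  # mirrors notifyMatchFailure

    x = torch.rand(4, 2) - 0.5
    assert torch.equal(convert_unary("aten.relu", x), torch.relu(x))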