Add unary tanh lowering.

pull/197/head
Sean Silva 2021-03-30 13:19:43 -07:00
parent b0ac04001d
commit c6d56fed8a
2 changed files with 26 additions and 0 deletions

View File

@ -47,6 +47,24 @@ class ConvertATenAdd : public OpRewritePattern<aten::AddOp> {
} }
}; };
/// Common conversion template for unary ops that map 1:1.
template <typename SourceOp, typename TargetOp>
class ConvertUnary : public OpRewritePattern<SourceOp> {
public:
using OpRewritePattern<SourceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SourceOp srcOp,
PatternRewriter &rewriter) const override {
auto operands = srcOp.getOperation()->getOperands();
auto results = srcOp.getOperation()->getResults();
assert(operands.size() == 1 && "expected unary op");
assert(results.size() == 1 && "expected single result op");
Type resultType = results[0].getType();
rewriter.replaceOpWithNewOp<TargetOp>(srcOp, resultType,
srcOp->getOperand(0));
return success();
}
};
/// Common conversion template for true binary elementwise ops. /// Common conversion template for true binary elementwise ops.
/// This does not apply to the handful of not-actually-binary PyTorch ops that /// This does not apply to the handful of not-actually-binary PyTorch ops that
/// have broadcastable self/other operands but may have additional parameters. /// have broadcastable self/other operands but may have additional parameters.
@ -152,6 +170,7 @@ class ConvertATenConv2d : public OpRewritePattern<aten::Conv2dOp> {
void mlir::NPCOMP::populateCoreATenToTCFPatterns(RewritePatternSet &patterns) { void mlir::NPCOMP::populateCoreATenToTCFPatterns(RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
patterns.add<ConvertATenAdd>(context); patterns.add<ConvertATenAdd>(context);
patterns.add<ConvertUnary<aten::TanhOp, tcf::TanhOp>>(context);
patterns.add<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context); patterns.add<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
patterns.add<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(context); patterns.add<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(context);
patterns.add<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context); patterns.add<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);

View File

@ -13,6 +13,13 @@ func @conv2d(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tens
return %3 : tensor<?x?x?x?xf32> return %3 : tensor<?x?x?x?xf32>
} }
// CHECK-LABEL: @unary_ops
func @unary_ops(%arg0: tensor<?x5x1xf32>) -> tensor<?x5x1xf32> {
// CHECK: tcf.tanh %arg0 : tensor<?x5x1xf32>
%0 = "aten.tanh"(%arg0) : (tensor<?x5x1xf32>) -> tensor<?x5x1xf32>
return %0 : tensor<?x5x1xf32>
}
// CHECK-LABEL: @binary_elementwise_ops // CHECK-LABEL: @binary_elementwise_ops
// NOTE: These are all template expanded, so just testing an examplar op and // NOTE: These are all template expanded, so just testing an examplar op and
// special cases. // special cases.