mirror of https://github.com/llvm/torch-mlir
Add unary tanh lowering.
parent
b0ac04001d
commit
c6d56fed8a
|
@ -47,6 +47,24 @@ class ConvertATenAdd : public OpRewritePattern<aten::AddOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Common conversion template for unary ops that map 1:1.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
class ConvertUnary : public OpRewritePattern<SourceOp> {
|
||||
public:
|
||||
using OpRewritePattern<SourceOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(SourceOp srcOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto operands = srcOp.getOperation()->getOperands();
|
||||
auto results = srcOp.getOperation()->getResults();
|
||||
assert(operands.size() == 1 && "expected unary op");
|
||||
assert(results.size() == 1 && "expected single result op");
|
||||
Type resultType = results[0].getType();
|
||||
rewriter.replaceOpWithNewOp<TargetOp>(srcOp, resultType,
|
||||
srcOp->getOperand(0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Common conversion template for true binary elementwise ops.
|
||||
/// This does not apply to the handful of not-actually-binary PyTorch ops that
|
||||
/// have broadcastable self/other operands but may have additional parameters.
|
||||
|
@ -152,6 +170,7 @@ class ConvertATenConv2d : public OpRewritePattern<aten::Conv2dOp> {
|
|||
void mlir::NPCOMP::populateCoreATenToTCFPatterns(RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<ConvertATenAdd>(context);
|
||||
patterns.add<ConvertUnary<aten::TanhOp, tcf::TanhOp>>(context);
|
||||
patterns.add<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
|
||||
patterns.add<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(context);
|
||||
patterns.add<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
|
||||
|
|
|
@ -13,6 +13,13 @@ func @conv2d(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tens
|
|||
return %3 : tensor<?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @unary_ops
|
||||
func @unary_ops(%arg0: tensor<?x5x1xf32>) -> tensor<?x5x1xf32> {
|
||||
// CHECK: tcf.tanh %arg0 : tensor<?x5x1xf32>
|
||||
%0 = "aten.tanh"(%arg0) : (tensor<?x5x1xf32>) -> tensor<?x5x1xf32>
|
||||
return %0 : tensor<?x5x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @binary_elementwise_ops
|
||||
// NOTE: These are all template expanded, so just testing an examplar op and
|
||||
// special cases.
|
||||
|
|
Loading…
Reference in New Issue