mirror of https://github.com/llvm/torch-mlir
Add unary tanh lowering.
parent
b0ac04001d
commit
c6d56fed8a
|
@ -47,6 +47,24 @@ class ConvertATenAdd : public OpRewritePattern<aten::AddOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Common conversion template for unary ops that map 1:1.
|
||||||
|
template <typename SourceOp, typename TargetOp>
|
||||||
|
class ConvertUnary : public OpRewritePattern<SourceOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<SourceOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(SourceOp srcOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto operands = srcOp.getOperation()->getOperands();
|
||||||
|
auto results = srcOp.getOperation()->getResults();
|
||||||
|
assert(operands.size() == 1 && "expected unary op");
|
||||||
|
assert(results.size() == 1 && "expected single result op");
|
||||||
|
Type resultType = results[0].getType();
|
||||||
|
rewriter.replaceOpWithNewOp<TargetOp>(srcOp, resultType,
|
||||||
|
srcOp->getOperand(0));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Common conversion template for true binary elementwise ops.
|
/// Common conversion template for true binary elementwise ops.
|
||||||
/// This does not apply to the handful of not-actually-binary PyTorch ops that
|
/// This does not apply to the handful of not-actually-binary PyTorch ops that
|
||||||
/// have broadcastable self/other operands but may have additional parameters.
|
/// have broadcastable self/other operands but may have additional parameters.
|
||||||
|
@ -152,6 +170,7 @@ class ConvertATenConv2d : public OpRewritePattern<aten::Conv2dOp> {
|
||||||
void mlir::NPCOMP::populateCoreATenToTCFPatterns(RewritePatternSet &patterns) {
|
void mlir::NPCOMP::populateCoreATenToTCFPatterns(RewritePatternSet &patterns) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
patterns.add<ConvertATenAdd>(context);
|
patterns.add<ConvertATenAdd>(context);
|
||||||
|
patterns.add<ConvertUnary<aten::TanhOp, tcf::TanhOp>>(context);
|
||||||
patterns.add<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
|
patterns.add<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
|
||||||
patterns.add<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(context);
|
patterns.add<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(context);
|
||||||
patterns.add<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
|
patterns.add<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
|
||||||
|
|
|
@ -13,6 +13,13 @@ func @conv2d(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tens
|
||||||
return %3 : tensor<?x?x?x?xf32>
|
return %3 : tensor<?x?x?x?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @unary_ops
|
||||||
|
func @unary_ops(%arg0: tensor<?x5x1xf32>) -> tensor<?x5x1xf32> {
|
||||||
|
// CHECK: tcf.tanh %arg0 : tensor<?x5x1xf32>
|
||||||
|
%0 = "aten.tanh"(%arg0) : (tensor<?x5x1xf32>) -> tensor<?x5x1xf32>
|
||||||
|
return %0 : tensor<?x5x1xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @binary_elementwise_ops
|
// CHECK-LABEL: @binary_elementwise_ops
|
||||||
// NOTE: These are all template expanded, so just testing an examplar op and
|
// NOTE: These are all template expanded, so just testing an examplar op and
|
||||||
// special cases.
|
// special cases.
|
||||||
|
|
Loading…
Reference in New Issue