[RefE2E] Add support for unary ops exp and tanh
This is fairly mechanical.

parent 6b69beae6a
commit f9b37c55b7
@@ -43,6 +43,29 @@ def TCF_MaxOp : BinaryArithmeticOp<"max"> {
   }];
 }
 
+class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    TCF_Op<mnemonic,
+        !listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,
+    AllTypesMatch<["operand", "result"]> {
+  let arguments = (ins AnyTensor:$operand);
+  let results = (outs AnyTensor:$result);
+  let assemblyFormat = "$operand attr-dict `:` type($operand)";
+}
+
+def TCF_ExpOp : UnaryArithmeticOp<"exp"> {
+  let summary = "base-e exponential";
+  let description = [{
+    See std.exp for more details.
+  }];
+}
+
+def TCF_TanhOp : UnaryArithmeticOp<"tanh"> {
+  let summary = "hyperbolic tangent";
+  let description = [{
+    See std.tanh for more details.
+  }];
+}
+
 // TODO: Generalize this op appropriately and add more verification.
 // For example, an unranked operand probably should be allowed and verified
 // dynamically in TCF->TCP lowering if needed.
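For reference, the declarative assemblyFormat above gives the new ops a textual form like the following (hypothetical IR, consistent with the tests added later in this commit; per the AllTypesMatch constraint, operand and result types must be identical):

    %0 = tcf.exp %arg0 : tensor<?xf32>
    %1 = tcf.tanh %0 : tensor<?xf32>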
@@ -42,6 +42,29 @@ def TCP_MaxOp : BinaryArithmeticOp<"max"> {
   }];
 }
 
+class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    TCP_Op<mnemonic,
+        !listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,
+    AllTypesMatch<["operand", "result"]> {
+  let arguments = (ins AnyTensor:$operand);
+  let results = (outs AnyTensor:$result);
+  let assemblyFormat = "$operand attr-dict `:` type($operand)";
+}
+
+def TCP_ExpOp : UnaryArithmeticOp<"exp"> {
+  let summary = "base-e exponential";
+  let description = [{
+    See std.exp for more details.
+  }];
+}
+
+def TCP_TanhOp : UnaryArithmeticOp<"tanh"> {
+  let summary = "hyperbolic tangent";
+  let description = [{
+    See std.tanh for more details.
+  }];
+}
+
 // TODO: Generalize this op appropriately and add more verification.
 // For example, should we have a single primitive that does multidimensional
 // contractions? + batching as well in the same op? In fact, if we want to
@@ -73,6 +73,10 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) {
   } else if (isa<tcf::MaxOp>(op)) {
     binaryOpResult = rewriter.create<tcp::MaxOp>(
         loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
+  } else {
+    op->dump();
+    llvm::report_fatal_error(
+        "unhandled op (see dump above): TCF->TCP binary elementwise");
   }
   rewriter.create<shape::AssumingYieldOp>(loc, binaryOpResult);
@@ -93,6 +97,33 @@ public:
 };
 } // namespace
 
+static LogicalResult
+matchAndRewriteUnaryElementwise(Operation *op, PatternRewriter &rewriter) {
+  if (isa<tcf::ExpOp>(op)) {
+    rewriter.replaceOpWithNewOp<tcp::ExpOp>(op, op->getOperand(0));
+  } else if (isa<tcf::TanhOp>(op)) {
+    rewriter.replaceOpWithNewOp<tcp::TanhOp>(op, op->getOperand(0));
+  } else {
+    op->dump();
+    llvm::report_fatal_error(
+        "unhandled op (see dump above): TCF->TCP unary elementwise");
+  }
+  return success();
+}
+
+namespace {
+template <typename SourceOp>
+class ConvertUnaryElementwise : public OpRewritePattern<SourceOp> {
+public:
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SourceOp op,
+                                PatternRewriter &rewriter) const override {
+    return matchAndRewriteUnaryElementwise(op, rewriter);
+  }
+};
+} // namespace
+
 namespace {
 class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
 public:
@@ -134,6 +165,8 @@ public:
     MLIRContext *context = &getContext();
 
     OwningRewritePatternList patterns;
+    patterns.insert<ConvertUnaryElementwise<tcf::ExpOp>,
+                    ConvertUnaryElementwise<tcf::TanhOp>>(context);
     patterns.insert<ConvertBinaryElementwise<tcf::AddOp>,
                     ConvertBinaryElementwise<tcf::MaxOp>>(context);
     patterns.insert<ConvertMatmul>(context);
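Unlike the binary ops, which must broadcast their operands inside a shape.assuming region, the unary ops convert 1:1, so matchAndRewriteUnaryElementwise is a plain replaceOpWithNewOp. The net effect on the IR looks like this (a hypothetical before/after, mirroring the lowering test below):

    // before: %0 = tcf.exp %arg0 : tensor<?xf32>
    // after:  %0 = tcp.exp %arg0 : tensor<?xf32>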
@@ -24,8 +24,8 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
     return {broadcastTo.shape()};
   }
 
-  // Binary elementwise ops.
-  if (isa<tcp::AddOp, tcp::MaxOp>(op)) {
+  // Elementwise ops.
+  if (isa<tcp::AddOp, tcp::MaxOp, tcp::ExpOp, tcp::TanhOp>(op)) {
     return {builder.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand(0))};
   }
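The rationale for the generalized check: for any elementwise op, unary or binary, the result shape equals the shape of the first operand, so no op-specific shape computation is needed. The builder call above emits IR along these lines (a sketch; the exact textual syntax of shape.shape_of at this revision is assumed):

    %shape = shape.shape_of %operand : tensor<?xf32>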
@@ -24,6 +24,7 @@ add_mlir_library(NPCOMPE2E
     MLIRSCFToStandard
     MLIRShapeToStandard
     MLIRStandardOps
+    MLIRStandardOpsTransforms
     MLIRStandardToLLVM
   )
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 #include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
@@ -707,6 +708,10 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
     patterns.insert<LowerModuleMetadata>(context);
     patterns.insert<LowerNpcomprtGlobalOp>(converter);
 
+    // TODO: Move these "std to std" legalizations to their own pass if we grow
+    // lots of these patterns.
+    populateExpandTanhPattern(patterns, context);
+
     if (failed(applyFullConversion(module, target, patterns))) {
       return signalPassFailure();
     }
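Context for the populateExpandTanhPattern call: std.tanh has no direct lowering to LLVM, so it is first expanded into ops that do lower (exp plus arithmetic). The upstream expansion exploits an identity of roughly this shape (a sketch of the math, not the exact pattern code):

    tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))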
@@ -132,27 +132,32 @@ public:
 };
 } // namespace
 
-static Value createLinalgBodyCalculationForBinaryElementwise(Operation *op,
-                                                             Value lhsBodyArg,
-                                                             Value rhsBodyArg,
-                                                             OpBuilder &builder,
-                                                             Location loc) {
+static Value createLinalgBodyCalculationForElementwiseOp(Operation *op,
+                                                         ValueRange bodyArgs,
+                                                         OpBuilder &builder,
+                                                         Location loc) {
   if (isa<tcp::AddOp>(op))
-    return builder.create<AddFOp>(loc, lhsBodyArg, rhsBodyArg);
+    return builder.create<AddFOp>(loc, bodyArgs[0], bodyArgs[1]);
 
   if (isa<tcp::MaxOp>(op)) {
     auto greater =
-        builder.create<CmpFOp>(loc, CmpFPredicate::OGT, lhsBodyArg, rhsBodyArg);
-    return builder.create<SelectOp>(loc, greater, lhsBodyArg, rhsBodyArg);
+        builder.create<CmpFOp>(loc, CmpFPredicate::OGT, bodyArgs[0], bodyArgs[1]);
+    return builder.create<SelectOp>(loc, greater, bodyArgs[0], bodyArgs[1]);
   }
 
+  if (isa<tcp::ExpOp>(op))
+    return builder.create<ExpOp>(loc, bodyArgs[0]);
+
+  if (isa<tcp::TanhOp>(op))
+    return builder.create<TanhOp>(loc, bodyArgs[0]);
+
   op->dump();
-  llvm::report_fatal_error(
-      "unhandled op (see dump above) when lowering binary elementwise ops");
+  llvm::report_fatal_error("unhandled op (see dump above): linalg body "
+                           "calculation for elementwise op");
 }
 
 static LogicalResult
-matchAndRewriteBinaryElementwiseOp(Operation *op, ArrayRef<Value> operands,
-                                   ConversionPatternRewriter &rewriter) {
+matchAndRewriteElementwiseOp(Operation *op, ArrayRef<Value> operands,
+                             ConversionPatternRewriter &rewriter) {
   Location loc = op->getLoc();
   Value result = op->getResult(0);
@@ -187,8 +192,8 @@ matchAndRewriteBinaryElementwiseOp(Operation *op, ArrayRef<Value> operands,
       /*iterator_types=*/iterators,
       /*bodyBuilder=*/
       [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
-        auto scalarResult = createLinalgBodyCalculationForBinaryElementwise(
-            op, regionArgs[0], regionArgs[1], builder, loc);
+        auto scalarResult = createLinalgBodyCalculationForElementwiseOp(
+            op, regionArgs, builder, loc);
         builder.create<linalg::YieldOp>(loc, ValueRange({scalarResult}));
       });
   rewriter.replaceOp(op, results);
@@ -197,13 +202,13 @@ matchAndRewriteBinaryElementwiseOp(Operation *op, ArrayRef<Value> operands,
 
 namespace {
 template <typename SourceOp>
-class LowerBinaryElementwiseOp : public OpConversionPattern<SourceOp> {
+class LowerElementwiseOp : public OpConversionPattern<SourceOp> {
 public:
   using OpConversionPattern<SourceOp>::OpConversionPattern;
   LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    return matchAndRewriteBinaryElementwiseOp(op, operands, rewriter);
+    return matchAndRewriteElementwiseOp(op, operands, rewriter);
   }
 };
 } // namespace
@@ -322,9 +327,10 @@ class LowerShapedResultsToMemref
 
     patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
     target.addIllegalOp<tcp::BroadcastToOp>();
-    patterns.insert<LowerBinaryElementwiseOp<tcp::AddOp>,
-                    LowerBinaryElementwiseOp<tcp::MaxOp>>(typeConverter,
-                                                          context);
+    patterns
+        .insert<LowerElementwiseOp<tcp::AddOp>, LowerElementwiseOp<tcp::MaxOp>,
+                LowerElementwiseOp<tcp::ExpOp>,
+                LowerElementwiseOp<tcp::TanhOp>>(typeConverter, context);
     target.addIllegalOp<tcp::AddOp, tcp::MaxOp>();
     patterns.insert<LowerTcpMatmulOp>(typeConverter, context);
     target.addIllegalOp<tcp::MatmulOp>();
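For a unary op such as tcp.exp, matchAndRewriteElementwiseOp emits a linalg.generic whose body is the single scalar op returned by createLinalgBodyCalculationForElementwiseOp. Schematically, for a 1-D operand (a hand-written sketch; the syntax and buffer types are assumed, not copied from the generated IR):

    #map = affine_map<(d0) -> (d0)>
    linalg.generic {indexing_maps = [#map, #map],
                    iterator_types = ["parallel"]}
        ins(%operand : memref<?xf32>) outs(%result : memref<?xf32>) {
    ^bb0(%in: f32, %out: f32):
      // Scalar body built by createLinalgBodyCalculationForElementwiseOp.
      %e = exp %in : f32
      linalg.yield %e : f32
    }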
@@ -1,5 +1,15 @@
 // RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s --dump-input=fail
 
+// CHECK-LABEL: func @unary_ops(
+// CHECK-SAME:    %[[ARG:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:         %[[RET:.*]] = tcp.exp %[[ARG]] : tensor<?xf32>
+// CHECK:         return %[[RET]] : tensor<?xf32>
+// CHECK:       }
+func @unary_ops(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.exp %arg0 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
 // CHECK-LABEL: func @tcf_add(
 // CHECK-SAME:    %[[LHS:.*]]: tensor<?xf32>,
 // CHECK-SAME:    %[[RHS:.*]]: tensor<?xf32>) -> tensor<?xf32> {
@@ -1,11 +1,13 @@
-// RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail
+// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
 
 // CHECK-LABEL: func @binary_elementwise
 func @binary_elementwise(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
   // CHECK: tcf.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   // CHECK: tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: tcf.exp %arg0 : tensor<?xf32>
   %0 = tcf.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %1 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %2 = tcf.exp %arg0 : tensor<?xf32>
   return
 }
@@ -14,8 +14,10 @@ func @global() {
 func @binary_elementwise(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
   // CHECK: tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   // CHECK: tcp.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: tcp.exp %arg0 : tensor<?xf32>
   %0 = tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %1 = tcp.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %2 = tcp.exp %arg0 : tensor<?xf32>
   return
 }
@@ -1,16 +0,0 @@
-// RUN: npcomp-run-mlir %s \
-// RUN:   -invoke max \
-// RUN:   -arg-value="dense<[1.0]> : tensor<1xf32>" \
-// RUN:   -arg-value="dense<[3.0]> : tensor<1xf32>" \
-// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
-// RUN:   | FileCheck %s --check-prefix=MAX
-
-// These ops share a lot of code paths. So we don't test the exact
-// broadcasting behavior and error checking for all of them.
-
-// MAX: output #0: dense<3.000000e+00> : tensor<1xf32>
-func @max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
-  %0 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  return %0 : tensor<?xf32>
-}
@@ -0,0 +1,39 @@
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke max \
+// RUN:   -arg-value="dense<[1.0]> : tensor<1xf32>" \
+// RUN:   -arg-value="dense<[3.0]> : tensor<1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=MAX
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke exp \
+// RUN:   -arg-value="dense<[0.0, 1.0]> : tensor<2xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=EXP
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke tanh \
+// RUN:   -arg-value="dense<[0.0, 1.0]> : tensor<2xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=TANH
+
+// These ops share a lot of code paths. So we don't test the exact
+// broadcasting behavior and error checking for all of them.
+
+// MAX: output #0: dense<3.000000e+00> : tensor<1xf32>
+func @max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// EXP: output #0: dense<[1.000000e+00, 2.71828175]> : tensor<2xf32>
+func @exp(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.exp %arg0 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// TANH: output #0: dense<[0.000000e+00, 0.761594116]> : tensor<2xf32>
+func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.tanh %arg0 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
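The FileCheck values above are just the f32 printouts of the underlying math: exp(0) = 1, exp(1) = e ≈ 2.71828175, tanh(0) = 0, and tanh(1) ≈ 0.761594116.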