[RefE2E] Add support for unary ops exp and tanh

This is fairly mechanical.
pull/60/head
Sean Silva 2020-09-24 17:14:21 -07:00
parent 6b69beae6a
commit f9b37c55b7
12 changed files with 165 additions and 37 deletions
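
For reference, the new ops are simple elementwise unary ops, and the tcf forms lower 1:1 onto their tcp counterparts. A minimal example, adapted from the TCF->TCP lowering test added below (the dynamic shape is just illustrative):

  %0 = tcf.exp %arg0 : tensor<?xf32>
  // after -convert-tcf-to-tcp:
  %0 = tcp.exp %arg0 : tensor<?xf32>

tcf.tanh / tcp.tanh follow the same pattern.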


@@ -43,6 +43,29 @@ def TCF_MaxOp : BinaryArithmeticOp<"max"> {
  }];
}

class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
    TCF_Op<mnemonic,
        !listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,
    AllTypesMatch<["operand", "result"]> {
  let arguments = (ins AnyTensor:$operand);
  let results = (outs AnyTensor:$result);
  let assemblyFormat = "$operand attr-dict `:` type($operand)";
}

def TCF_ExpOp : UnaryArithmeticOp<"exp"> {
  let summary = "base-e exponential";
  let description = [{
    See std.exp for more details.
  }];
}

def TCF_TanhOp : UnaryArithmeticOp<"tanh"> {
  let summary = "hyperbolic tangent";
  let description = [{
    See std.tanh for more details.
  }];
}

// TODO: Generalize this op appropriately and add more verification.
// For example, an unranked operand probably should be allowed and verified
// dynamically in TCF->TCP lowering if needed.


@@ -42,6 +42,29 @@ def TCP_MaxOp : BinaryArithmeticOp<"max"> {
  }];
}

class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
    TCP_Op<mnemonic,
        !listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,
    AllTypesMatch<["operand", "result"]> {
  let arguments = (ins AnyTensor:$operand);
  let results = (outs AnyTensor:$result);
  let assemblyFormat = "$operand attr-dict `:` type($operand)";
}

def TCP_ExpOp : UnaryArithmeticOp<"exp"> {
  let summary = "base-e exponential";
  let description = [{
    See std.exp for more details.
  }];
}

def TCP_TanhOp : UnaryArithmeticOp<"tanh"> {
  let summary = "hyperbolic tangent";
  let description = [{
    See std.tanh for more details.
  }];
}

// TODO: Generalize this op appropriately and add more verification.
// For example, should we have a single primitive that does multidimensional
// contractions? + batching as well in the same op? In fact, if we want to


@@ -73,6 +73,10 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) {
  } else if (isa<tcf::MaxOp>(op)) {
    binaryOpResult = rewriter.create<tcp::MaxOp>(
        loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
  } else {
    op->dump();
    llvm::report_fatal_error(
        "unhandled op (see dump above): TCF->TCP binary elementwise");
  }
  rewriter.create<shape::AssumingYieldOp>(loc, binaryOpResult);
@@ -93,6 +97,33 @@ public:
};
} // namespace

static LogicalResult
matchAndRewriteUnaryElementwise(Operation *op, PatternRewriter &rewriter) {
  if (isa<tcf::ExpOp>(op)) {
    rewriter.replaceOpWithNewOp<tcp::ExpOp>(op, op->getOperand(0));
  } else if (isa<tcf::TanhOp>(op)) {
    rewriter.replaceOpWithNewOp<tcp::TanhOp>(op, op->getOperand(0));
  } else {
    op->dump();
    llvm::report_fatal_error(
        "unhandled op (see dump above): TCF->TCP unary elementwise");
  }
  return success();
}

namespace {
template <typename SourceOp>
class ConvertUnaryElementwise : public OpRewritePattern<SourceOp> {
public:
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    return matchAndRewriteUnaryElementwise(op, rewriter);
  }
};
} // namespace

namespace {
class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
public:

@@ -134,6 +165,8 @@ public:
    MLIRContext *context = &getContext();
    OwningRewritePatternList patterns;
    patterns.insert<ConvertUnaryElementwise<tcf::ExpOp>,
                    ConvertUnaryElementwise<tcf::TanhOp>>(context);
    patterns.insert<ConvertBinaryElementwise<tcf::AddOp>,
                    ConvertBinaryElementwise<tcf::MaxOp>>(context);
    patterns.insert<ConvertMatmul>(context);


@@ -24,8 +24,8 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
    return {broadcastTo.shape()};
  }

-  // Binary elementwise ops.
-  if (isa<tcp::AddOp, tcp::MaxOp>(op)) {
+  // Elementwise ops.
+  if (isa<tcp::AddOp, tcp::MaxOp, tcp::ExpOp, tcp::TanhOp>(op)) {
    return {builder.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand(0))};
  }


@@ -24,6 +24,7 @@ add_mlir_library(NPCOMPE2E
  MLIRSCFToStandard
  MLIRShapeToStandard
  MLIRStandardOps
  MLIRStandardOpsTransforms
  MLIRStandardToLLVM
)


@@ -12,6 +12,7 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"

@@ -707,6 +708,10 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
    patterns.insert<LowerModuleMetadata>(context);
    patterns.insert<LowerNpcomprtGlobalOp>(converter);

    // TODO: Move these "std to std" legalizations to their own pass if we grow
    // lots of these patterns.
    populateExpandTanhPattern(patterns, context);

    if (failed(applyFullConversion(module, target, patterns))) {
      return signalPassFailure();
    }
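
Note: the std dialect's tanh op has no direct lowering to LLVM, which is presumably why populateExpandTanhPattern is pulled in above; the upstream pattern rewrites tanh in terms of exp, roughly

  tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))   for x >= 0
  tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)     for x < 0

so once exp is handled, tanh comes along essentially for free at the LLVM level.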


@@ -132,27 +132,32 @@ public:
};
} // namespace

-static Value createLinalgBodyCalculationForBinaryElementwise(Operation *op,
-                                                              Value lhsBodyArg,
-                                                              Value rhsBodyArg,
+static Value createLinalgBodyCalculationForElementwiseOp(Operation *op,
+                                                          ValueRange bodyArgs,
                                                           OpBuilder &builder,
                                                           Location loc) {
  if (isa<tcp::AddOp>(op))
-    return builder.create<AddFOp>(loc, lhsBodyArg, rhsBodyArg);
+    return builder.create<AddFOp>(loc, bodyArgs[0], bodyArgs[1]);
  if (isa<tcp::MaxOp>(op)) {
    auto greater =
-        builder.create<CmpFOp>(loc, CmpFPredicate::OGT, lhsBodyArg, rhsBodyArg);
-    return builder.create<SelectOp>(loc, greater, lhsBodyArg, rhsBodyArg);
+        builder.create<CmpFOp>(loc, CmpFPredicate::OGT, bodyArgs[0], bodyArgs[1]);
+    return builder.create<SelectOp>(loc, greater, bodyArgs[0], bodyArgs[1]);
  }
  if (isa<tcp::ExpOp>(op))
    return builder.create<ExpOp>(loc, bodyArgs[0]);
  if (isa<tcp::TanhOp>(op))
    return builder.create<TanhOp>(loc, bodyArgs[0]);
  op->dump();
-  llvm::report_fatal_error(
-      "unhandled op (see dump above) when lowering binary elementwise ops");
+  llvm::report_fatal_error("unhandled op (see dump above): linalg body "
+                           "calculation for elementwise op");
}

static LogicalResult
-matchAndRewriteBinaryElementwiseOp(Operation *op, ArrayRef<Value> operands,
-                                   ConversionPatternRewriter &rewriter) {
+matchAndRewriteElementwiseOp(Operation *op, ArrayRef<Value> operands,
+                             ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  Value result = op->getResult(0);
@@ -187,8 +192,8 @@ matchAndRewriteBinaryElementwiseOp(Operation *op, ArrayRef<Value> operands,
      /*iterator_types=*/iterators,
      /*bodyBuilder=*/
      [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
-        auto scalarResult = createLinalgBodyCalculationForBinaryElementwise(
-            op, regionArgs[0], regionArgs[1], builder, loc);
+        auto scalarResult = createLinalgBodyCalculationForElementwiseOp(
+            op, regionArgs, builder, loc);
        builder.create<linalg::YieldOp>(loc, ValueRange({scalarResult}));
      });
  rewriter.replaceOp(op, results);
@@ -197,13 +202,13 @@ matchAndRewriteBinaryElementwiseOp(Operation *op, ArrayRef<Value> operands,
namespace {
template <typename SourceOp>
-class LowerBinaryElementwiseOp : public OpConversionPattern<SourceOp> {
+class LowerElementwiseOp : public OpConversionPattern<SourceOp> {
public:
  using OpConversionPattern<SourceOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
-    return matchAndRewriteBinaryElementwiseOp(op, operands, rewriter);
+    return matchAndRewriteElementwiseOp(op, operands, rewriter);
  }
};
} // namespace
@@ -322,9 +327,10 @@ class LowerShapedResultsToMemref
    patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
    target.addIllegalOp<tcp::BroadcastToOp>();
-    patterns.insert<LowerBinaryElementwiseOp<tcp::AddOp>,
-                    LowerBinaryElementwiseOp<tcp::MaxOp>>(typeConverter,
-                                                          context);
+    patterns
+        .insert<LowerElementwiseOp<tcp::AddOp>, LowerElementwiseOp<tcp::MaxOp>,
+                LowerElementwiseOp<tcp::ExpOp>,
+                LowerElementwiseOp<tcp::TanhOp>>(typeConverter, context);
    target.addIllegalOp<tcp::AddOp, tcp::MaxOp>();
    patterns.insert<LowerTcpMatmulOp>(typeConverter, context);
    target.addIllegalOp<tcp::MatmulOp>();


@@ -1,5 +1,15 @@
// RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s --dump-input=fail

// CHECK-LABEL: func @unary_ops(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: %[[RET:.*]] = tcp.exp %[[ARG]] : tensor<?xf32>
// CHECK: return %[[RET]] : tensor<?xf32>
// CHECK: }
func @unary_ops(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = tcf.exp %arg0 : tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @tcf_add(
// CHECK-SAME: %[[LHS:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?xf32>) -> tensor<?xf32> {


@@ -1,11 +1,13 @@
-// RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail
+// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail

// CHECK-LABEL: func @binary_elementwise
func @binary_elementwise(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
  // CHECK: tcf.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  // CHECK: tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  // CHECK: tcf.exp %arg0 : tensor<?xf32>
  %0 = tcf.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  %1 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  %2 = tcf.exp %arg0 : tensor<?xf32>
  return
}


@@ -14,8 +14,10 @@ func @global() {
func @binary_elementwise(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
  // CHECK: tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  // CHECK: tcp.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  // CHECK: tcp.exp %arg0 : tensor<?xf32>
  %0 = tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  %1 = tcp.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  %2 = tcp.exp %arg0 : tensor<?xf32>
  return
}


@@ -1,16 +0,0 @@
// RUN: npcomp-run-mlir %s \
// RUN:   -invoke max \
// RUN:   -arg-value="dense<[1.0]> : tensor<1xf32>" \
// RUN:   -arg-value="dense<[3.0]> : tensor<1xf32>" \
// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN:   | FileCheck %s --check-prefix=MAX

// These ops share a lot of code paths. So we don't test the exact
// broadcasting behavior and error checking for all of them.

// MAX: output #0: dense<3.000000e+00> : tensor<1xf32>
func @max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}


@@ -0,0 +1,39 @@
// RUN: npcomp-run-mlir %s \
// RUN:   -invoke max \
// RUN:   -arg-value="dense<[1.0]> : tensor<1xf32>" \
// RUN:   -arg-value="dense<[3.0]> : tensor<1xf32>" \
// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN:   | FileCheck %s --check-prefix=MAX

// RUN: npcomp-run-mlir %s \
// RUN:   -invoke exp \
// RUN:   -arg-value="dense<[0.0, 1.0]> : tensor<2xf32>" \
// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN:   | FileCheck %s --check-prefix=EXP

// RUN: npcomp-run-mlir %s \
// RUN:   -invoke tanh \
// RUN:   -arg-value="dense<[0.0, 1.0]> : tensor<2xf32>" \
// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN:   | FileCheck %s --check-prefix=TANH

// These ops share a lot of code paths. So we don't test the exact
// broadcasting behavior and error checking for all of them.

// MAX: output #0: dense<3.000000e+00> : tensor<1xf32>
func @max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// EXP: output #0: dense<[1.000000e+00, 2.71828175]> : tensor<2xf32>
func @exp(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = tcf.exp %arg0 : tensor<?xf32>
  return %0 : tensor<?xf32>
}

// TANH: output #0: dense<[0.000000e+00, 0.761594116]> : tensor<2xf32>
func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = tcf.tanh %arg0 : tensor<?xf32>
  return %0 : tensor<?xf32>
}