mirror of https://github.com/llvm/torch-mlir
[RefE2E] Add support for unary ops exp and tanh

This is fairly mechanical.

ref: pull/60/head
parent: 6b69beae6a
commit: f9b37c55b7
@@ -43,6 +43,29 @@ def TCF_MaxOp : BinaryArithmeticOp<"max"> {
   }];
 }
 
+class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    TCF_Op<mnemonic,
+           !listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,
+    AllTypesMatch<["operand", "result"]> {
+  let arguments = (ins AnyTensor:$operand);
+  let results = (outs AnyTensor:$result);
+  let assemblyFormat = "$operand attr-dict `:` type($operand)";
+}
+
+def TCF_ExpOp : UnaryArithmeticOp<"exp"> {
+  let summary = "base-e exponential";
+  let description = [{
+    See std.exp for more details.
+  }];
+}
+
+def TCF_TanhOp : UnaryArithmeticOp<"tanh"> {
+  let summary = "hyperbolic tangent";
+  let description = [{
+    See std.tanh for more details.
+  }];
+}
+
 // TODO: Generalize this op appropriately and add more verification.
 // For example, an unranked operand probably should be allowed and verified
 // dynamically in TCF->TCP lowering if needed.
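The declared `assemblyFormat` gives both new ops the compact single-type syntax exercised by the tests later in this commit, for example:

    %0 = tcf.exp %arg0 : tensor<?xf32>
    %1 = tcf.tanh %arg0 : tensor<?xf32>

Because `AllTypesMatch<["operand", "result"]>` constrains the operand and result types to agree, only one type needs to be printed.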
@@ -42,6 +42,29 @@ def TCP_MaxOp : BinaryArithmeticOp<"max"> {
   }];
 }
 
+class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    TCP_Op<mnemonic,
+           !listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,
+    AllTypesMatch<["operand", "result"]> {
+  let arguments = (ins AnyTensor:$operand);
+  let results = (outs AnyTensor:$result);
+  let assemblyFormat = "$operand attr-dict `:` type($operand)";
+}
+
+def TCP_ExpOp : UnaryArithmeticOp<"exp"> {
+  let summary = "base-e exponential";
+  let description = [{
+    See std.exp for more details.
+  }];
+}
+
+def TCP_TanhOp : UnaryArithmeticOp<"tanh"> {
+  let summary = "hyperbolic tangent";
+  let description = [{
+    See std.tanh for more details.
+  }];
+}
+
 // TODO: Generalize this op appropriately and add more verification.
 // For example, should we have a single primitive that does multidimensional
 // contractions? + batching as well in the same op? In fact, if we want to
@@ -73,6 +73,10 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) {
   } else if (isa<tcf::MaxOp>(op)) {
     binaryOpResult = rewriter.create<tcp::MaxOp>(
         loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
+  } else {
+    op->dump();
+    llvm::report_fatal_error(
+        "unhandled op (see dump above): TCF->TCP binary elementwise");
   }
   rewriter.create<shape::AssumingYieldOp>(loc, binaryOpResult);

@@ -93,6 +97,33 @@ public:
 };
 } // namespace
 
+static LogicalResult
+matchAndRewriteUnaryElementwise(Operation *op, PatternRewriter &rewriter) {
+  if (isa<tcf::ExpOp>(op)) {
+    rewriter.replaceOpWithNewOp<tcp::ExpOp>(op, op->getOperand(0));
+  } else if (isa<tcf::TanhOp>(op)) {
+    rewriter.replaceOpWithNewOp<tcp::TanhOp>(op, op->getOperand(0));
+  } else {
+    op->dump();
+    llvm::report_fatal_error(
+        "unhandled op (see dump above): TCF->TCP unary elementwise");
+  }
+  return success();
+}
+
+namespace {
+template <typename SourceOp>
+class ConvertUnaryElementwise : public OpRewritePattern<SourceOp> {
+public:
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SourceOp op,
+                                PatternRewriter &rewriter) const override {
+    return matchAndRewriteUnaryElementwise(op, rewriter);
+  }
+};
+} // namespace
+
 namespace {
 class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
 public:

@@ -134,6 +165,8 @@ public:
     MLIRContext *context = &getContext();
 
     OwningRewritePatternList patterns;
+    patterns.insert<ConvertUnaryElementwise<tcf::ExpOp>,
+                    ConvertUnaryElementwise<tcf::TanhOp>>(context);
    patterns.insert<ConvertBinaryElementwise<tcf::AddOp>,
                    ConvertBinaryElementwise<tcf::MaxOp>>(context);
    patterns.insert<ConvertMatmul>(context);
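Unlike the binary ops, the unary ops are shape-preserving and need no broadcasting, so the TCF->TCP conversion is a plain one-for-one op replacement rather than the `shape.assuming` scaffolding used for `add`/`max`. The test added later in this commit checks the effect directly:

    // Input:
    %0 = tcf.exp %arg0 : tensor<?xf32>
    // After -convert-tcf-to-tcp:
    %0 = tcp.exp %arg0 : tensor<?xf32>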
@@ -24,8 +24,8 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
     return {broadcastTo.shape()};
   }
 
-  // Binary elementwise ops.
-  if (isa<tcp::AddOp, tcp::MaxOp>(op)) {
+  // Elementwise ops.
+  if (isa<tcp::AddOp, tcp::MaxOp, tcp::ExpOp, tcp::TanhOp>(op)) {
     return {builder.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand(0))};
   }
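For all four elementwise ops the result shape equals the shape of the first operand (the binary TCP ops have already had their operands broadcast by this point), so a single `shape.shape_of` on operand 0 serves as the bypassed result shape. A minimal sketch of what gets materialized for a `tcp.exp` result (op spelling approximate for the MLIR of this era):

    %shape = shape.shape_of %arg0 : tensor<?xf32>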
@@ -24,6 +24,7 @@ add_mlir_library(NPCOMPE2E
     MLIRSCFToStandard
     MLIRShapeToStandard
     MLIRStandardOps
+    MLIRStandardOpsTransforms
     MLIRStandardToLLVM
   )
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 #include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"

@@ -707,6 +708,10 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
     patterns.insert<LowerModuleMetadata>(context);
     patterns.insert<LowerNpcomprtGlobalOp>(converter);
 
+    // TODO: Move these "std to std" legalizations to their own pass if we grow
+    // lots of these patterns.
+    populateExpandTanhPattern(patterns, context);
+
     if (failed(applyFullConversion(module, target, patterns))) {
       return signalPassFailure();
     }
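`tanh` has no direct lowering to LLVM here, so upstream MLIR's `populateExpandTanhPattern` rewrites `std.tanh` into `exp`-based arithmetic before the final conversion (hence the new `MLIRStandardOpsTransforms` dependency above). Schematically the expansion uses tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x)), with a mirrored exp(2x) form for negative inputs for numerical stability. A rough scalar sketch of the positive branch, with constants `%cst_n2` (-2.0) and `%cst_1` (1.0) assumed predefined:

    %neg2x = mulf %x, %cst_n2 : f32
    %e     = exp %neg2x : f32
    %num   = subf %cst_1, %e : f32
    %den   = addf %cst_1, %e : f32
    %tanh  = divf %num, %den : f32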
@@ -132,27 +132,32 @@ public:
 };
 } // namespace
 
-static Value createLinalgBodyCalculationForBinaryElementwise(Operation *op,
-                                                             Value lhsBodyArg,
-                                                             Value rhsBodyArg,
-                                                             OpBuilder &builder,
-                                                             Location loc) {
+static Value createLinalgBodyCalculationForElementwiseOp(Operation *op,
+                                                         ValueRange bodyArgs,
+                                                         OpBuilder &builder,
+                                                         Location loc) {
   if (isa<tcp::AddOp>(op))
-    return builder.create<AddFOp>(loc, lhsBodyArg, rhsBodyArg);
+    return builder.create<AddFOp>(loc, bodyArgs[0], bodyArgs[1]);
 
   if (isa<tcp::MaxOp>(op)) {
     auto greater =
-        builder.create<CmpFOp>(loc, CmpFPredicate::OGT, lhsBodyArg, rhsBodyArg);
-    return builder.create<SelectOp>(loc, greater, lhsBodyArg, rhsBodyArg);
+        builder.create<CmpFOp>(loc, CmpFPredicate::OGT, bodyArgs[0], bodyArgs[1]);
+    return builder.create<SelectOp>(loc, greater, bodyArgs[0], bodyArgs[1]);
   }
 
+  if (isa<tcp::ExpOp>(op))
+    return builder.create<ExpOp>(loc, bodyArgs[0]);
+
+  if (isa<tcp::TanhOp>(op))
+    return builder.create<TanhOp>(loc, bodyArgs[0]);
+
   op->dump();
-  llvm::report_fatal_error(
-      "unhandled op (see dump above) when lowering binary elementwise ops");
+  llvm::report_fatal_error("unhandled op (see dump above): linalg body "
+                           "calculation for elementwise op");
 }
 
 static LogicalResult
-matchAndRewriteBinaryElementwiseOp(Operation *op, ArrayRef<Value> operands,
-                                   ConversionPatternRewriter &rewriter) {
+matchAndRewriteElementwiseOp(Operation *op, ArrayRef<Value> operands,
+                             ConversionPatternRewriter &rewriter) {
   Location loc = op->getLoc();
   Value result = op->getResult(0);

@@ -187,8 +192,8 @@ matchAndRewriteBinaryElementwiseOp(Operation *op, ArrayRef<Value> operands,
         /*iterator_types=*/iterators,
         /*bodyBuilder=*/
         [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
-          auto scalarResult = createLinalgBodyCalculationForBinaryElementwise(
-              op, regionArgs[0], regionArgs[1], builder, loc);
+          auto scalarResult = createLinalgBodyCalculationForElementwiseOp(
+              op, regionArgs, builder, loc);
           builder.create<linalg::YieldOp>(loc, ValueRange({scalarResult}));
         });
     rewriter.replaceOp(op, results);

@@ -197,13 +202,13 @@ matchAndRewriteBinaryElementwiseOp(Operation *op, ArrayRef<Value> operands,
 namespace {
 template <typename SourceOp>
-class LowerBinaryElementwiseOp : public OpConversionPattern<SourceOp> {
+class LowerElementwiseOp : public OpConversionPattern<SourceOp> {
 public:
   using OpConversionPattern<SourceOp>::OpConversionPattern;
   LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    return matchAndRewriteBinaryElementwiseOp(op, operands, rewriter);
+    return matchAndRewriteElementwiseOp(op, operands, rewriter);
   }
 };
 } // namespace

@@ -322,9 +327,10 @@ class LowerShapedResultsToMemref
     patterns.insert<LowerBroadcastToToLoopsPattern>(typeConverter, context);
     target.addIllegalOp<tcp::BroadcastToOp>();
-    patterns.insert<LowerBinaryElementwiseOp<tcp::AddOp>,
-                    LowerBinaryElementwiseOp<tcp::MaxOp>>(typeConverter,
-                                                          context);
+    patterns
+        .insert<LowerElementwiseOp<tcp::AddOp>, LowerElementwiseOp<tcp::MaxOp>,
+                LowerElementwiseOp<tcp::ExpOp>,
+                LowerElementwiseOp<tcp::TanhOp>>(typeConverter, context);
     target.addIllegalOp<tcp::AddOp, tcp::MaxOp>();
     patterns.insert<LowerTcpMatmulOp>(typeConverter, context);
     target.addIllegalOp<tcp::MatmulOp>();
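With this change, every TCP elementwise op, unary or binary, lowers through the same path: a `linalg.generic` whose region applies the corresponding scalar std op to the region arguments. A sketch of the result for `tcp.exp`, written in present-day linalg/math syntax for readability (the exact dialect spelling at the time of this commit differed):

    #map = affine_map<(d0) -> (d0)>
    linalg.generic {indexing_maps = [#map, #map],
                    iterator_types = ["parallel"]}
        ins(%operand : memref<?xf32>) outs(%result : memref<?xf32>) {
      ^bb0(%in: f32, %out: f32):
        %0 = math.exp %in : f32
        linalg.yield %0 : f32
    }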
@@ -1,5 +1,15 @@
 // RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s --dump-input=fail
 
+// CHECK-LABEL: func @unary_ops(
+// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[RET:.*]] = tcp.exp %[[ARG]] : tensor<?xf32>
+// CHECK: return %[[RET]] : tensor<?xf32>
+// CHECK: }
+func @unary_ops(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.exp %arg0 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
 // CHECK-LABEL: func @tcf_add(
 // CHECK-SAME: %[[LHS:.*]]: tensor<?xf32>,
 // CHECK-SAME: %[[RHS:.*]]: tensor<?xf32>) -> tensor<?xf32> {
@@ -1,11 +1,13 @@
-// RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail
+// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
 
 // CHECK-LABEL: func @binary_elementwise
 func @binary_elementwise(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
   // CHECK: tcf.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   // CHECK: tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: tcf.exp %arg0 : tensor<?xf32>
   %0 = tcf.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %1 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %2 = tcf.exp %arg0 : tensor<?xf32>
   return
 }
@@ -14,8 +14,10 @@ func @global() {
 func @binary_elementwise(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
   // CHECK: tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   // CHECK: tcp.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: tcp.exp %arg0 : tensor<?xf32>
   %0 = tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %1 = tcp.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %2 = tcp.exp %arg0 : tensor<?xf32>
   return
 }
@@ -1,16 +0,0 @@
-// RUN: npcomp-run-mlir %s \
-// RUN:   -invoke max \
-// RUN:   -arg-value="dense<[1.0]> : tensor<1xf32>" \
-// RUN:   -arg-value="dense<[3.0]> : tensor<1xf32>" \
-// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
-// RUN:   | FileCheck %s --check-prefix=MAX
-
-// These ops share a lot of code paths. So we don't test the exact
-// broadcasting behavior and error checking for all of them.
-
-// MAX: output #0: dense<3.000000e+00> : tensor<1xf32>
-func @max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
-  %0 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  return %0 : tensor<?xf32>
-}
@@ -0,0 +1,39 @@
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke max \
+// RUN:   -arg-value="dense<[1.0]> : tensor<1xf32>" \
+// RUN:   -arg-value="dense<[3.0]> : tensor<1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=MAX
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke exp \
+// RUN:   -arg-value="dense<[0.0, 1.0]> : tensor<2xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=EXP
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke tanh \
+// RUN:   -arg-value="dense<[0.0, 1.0]> : tensor<2xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=TANH
+
+// These ops share a lot of code paths. So we don't test the exact
+// broadcasting behavior and error checking for all of them.
+
+// MAX: output #0: dense<3.000000e+00> : tensor<1xf32>
+func @max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// EXP: output #0: dense<[1.000000e+00, 2.71828175]> : tensor<2xf32>
+func @exp(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.exp %arg0 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// TANH: output #0: dense<[0.000000e+00, 0.761594116]> : tensor<2xf32>
+func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.tanh %arg0 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}