diff --git a/include/npcomp/Dialect/TCF/IR/TCFOps.td b/include/npcomp/Dialect/TCF/IR/TCFOps.td
index 0c19d5739..9862da514 100644
--- a/include/npcomp/Dialect/TCF/IR/TCFOps.td
+++ b/include/npcomp/Dialect/TCF/IR/TCFOps.td
@@ -43,6 +43,29 @@ def TCF_MaxOp : BinaryArithmeticOp<"max"> {
   }];
 }
 
+class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    TCF_Op<mnemonic,
+        !listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,
+    AllTypesMatch<["operand", "result"]> {
+  let arguments = (ins AnyTensor:$operand);
+  let results = (outs AnyTensor:$result);
+  let assemblyFormat = "$operand attr-dict `:` type($operand)";
+}
+
+def TCF_ExpOp : UnaryArithmeticOp<"exp"> {
+  let summary = "base-e exponential";
+  let description = [{
+    See std.exp for more details.
+  }];
+}
+
+def TCF_TanhOp : UnaryArithmeticOp<"tanh"> {
+  let summary = "hyperbolic tangent";
+  let description = [{
+    See std.tanh for more details.
+  }];
+}
+
 // TODO: Generalize this op appropriately and add more verification.
 // For example, an unranked operand probably should be allowed and verified
 // dynamically in TCF->TCP lowering if needed.
diff --git a/include/npcomp/Dialect/TCP/IR/TCPOps.td b/include/npcomp/Dialect/TCP/IR/TCPOps.td
index daabffc7a..d766a5ff3 100644
--- a/include/npcomp/Dialect/TCP/IR/TCPOps.td
+++ b/include/npcomp/Dialect/TCP/IR/TCPOps.td
@@ -42,6 +42,29 @@ def TCP_MaxOp : BinaryArithmeticOp<"max"> {
   }];
 }
 
+class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
+    TCP_Op<mnemonic,
+        !listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,
+    AllTypesMatch<["operand", "result"]> {
+  let arguments = (ins AnyTensor:$operand);
+  let results = (outs AnyTensor:$result);
+  let assemblyFormat = "$operand attr-dict `:` type($operand)";
+}
+
+def TCP_ExpOp : UnaryArithmeticOp<"exp"> {
+  let summary = "base-e exponential";
+  let description = [{
+    See std.exp for more details.
+  }];
+}
+
+def TCP_TanhOp : UnaryArithmeticOp<"tanh"> {
+  let summary = "hyperbolic tangent";
+  let description = [{
+    See std.tanh for more details.
+  }];
+}
+
 // TODO: Generalize this op appropriately and add more verification.
 // For example, should we have a single primitive that does multidimensional
 // contractions? + batching as well in the same op? In fact, if we want to
diff --git a/lib/Conversion/TCFToTCP/TCFToTCP.cpp b/lib/Conversion/TCFToTCP/TCFToTCP.cpp
index 38b0789ad..37fbaaadd 100644
--- a/lib/Conversion/TCFToTCP/TCFToTCP.cpp
+++ b/lib/Conversion/TCFToTCP/TCFToTCP.cpp
@@ -73,6 +73,10 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) {
   } else if (isa<tcf::MaxOp>(op)) {
     binaryOpResult = rewriter.create<tcp::MaxOp>(
         loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
+  } else {
+    op->dump();
+    llvm::report_fatal_error(
+        "unhandled op (see dump above): TCF->TCP binary elementwise");
   }
   rewriter.create<tcp::YieldOp>(loc, binaryOpResult);
 
@@ -93,6 +97,33 @@ public:
 };
 } // namespace
 
+static LogicalResult
+matchAndRewriteUnaryElementwise(Operation *op, PatternRewriter &rewriter) {
+  if (isa<tcf::ExpOp>(op)) {
+    rewriter.replaceOpWithNewOp<tcp::ExpOp>(op, op->getOperand(0));
+  } else if (isa<tcf::TanhOp>(op)) {
+    rewriter.replaceOpWithNewOp<tcp::TanhOp>(op, op->getOperand(0));
+  } else {
+    op->dump();
+    llvm::report_fatal_error(
+        "unhandled op (see dump above): TCF->TCP unary elementwise");
+  }
+  return success();
+
+}
+
+namespace {
+template <typename SourceOp>
+class ConvertUnaryElementwise : public OpRewritePattern<SourceOp> {
+public:
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SourceOp op,
+                                PatternRewriter &rewriter) const override {
+    return matchAndRewriteUnaryElementwise(op, rewriter);
+  }
+};
+} // namespace
+
 namespace {
 class ConvertMatmul : public OpRewritePattern<tcf::MatmulOp> {
 public:
@@ -134,6 +165,8 @@ public:
     MLIRContext *context = &getContext();
 
     OwningRewritePatternList patterns;
+    patterns.insert<ConvertUnaryElementwise<tcf::ExpOp>,
+                    ConvertUnaryElementwise<tcf::TanhOp>>(context);
     patterns.insert<ConvertBinaryElementwise<tcf::AddOp>,
                     ConvertBinaryElementwise<tcf::MaxOp>>(context);
     patterns.insert<ConvertMatmul>(context);
diff --git a/lib/E2E/BypassShapes.cpp b/lib/E2E/BypassShapes.cpp
index 60a4017b5..83a14e151 100644
--- a/lib/E2E/BypassShapes.cpp
+++ b/lib/E2E/BypassShapes.cpp
@@ -24,8 +24,8 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
     return {broadcastTo.shape()};
   }
 
-  // Binary elementwise ops.
-  if (isa<tcp::AddOp, tcp::MaxOp>(op)) {
+  // Elementwise ops.
+  if (isa<tcp::AddOp, tcp::MaxOp, tcp::ExpOp, tcp::TanhOp>(op)) {
     return {builder.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand(0))};
   }
 
diff --git a/lib/E2E/CMakeLists.txt b/lib/E2E/CMakeLists.txt
index 4178b98c4..c74c9e2e4 100644
--- a/lib/E2E/CMakeLists.txt
+++ b/lib/E2E/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_library(NPCOMPE2E
     MLIRSCFToStandard
     MLIRShapeToStandard
     MLIRStandardOps
+    MLIRStandardOpsTransforms
     MLIRStandardToLLVM
   )
 
diff --git a/lib/E2E/LowerToLLVM.cpp b/lib/E2E/LowerToLLVM.cpp
index 72cda7be6..ea317f5c0 100644
--- a/lib/E2E/LowerToLLVM.cpp
+++ b/lib/E2E/LowerToLLVM.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 #include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
@@ -707,6 +708,10 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
     patterns.insert<...>(context);
     patterns.insert<...>(converter);
 
+    // TODO: Move these "std to std" legalizations to their own pass if we grow
+    // lots of these patterns.
+    populateExpandTanhPattern(patterns, context);
+
     if (failed(applyFullConversion(module, target, patterns))) {
       return signalPassFailure();
     }
diff --git a/lib/E2E/TensorToMemref/LowerShapedResultsToMemref.cpp b/lib/E2E/TensorToMemref/LowerShapedResultsToMemref.cpp
index 2a530d3ce..101ddb83b 100644
--- a/lib/E2E/TensorToMemref/LowerShapedResultsToMemref.cpp
+++ b/lib/E2E/TensorToMemref/LowerShapedResultsToMemref.cpp
@@ -132,27 +132,32 @@ public:
 };
 } // namespace
 
-static Value createLinalgBodyCalculationForBinaryElementwise(Operation *op,
-                                                             Value lhsBodyArg,
-                                                             Value rhsBodyArg,
-                                                             OpBuilder &builder,
-                                                             Location loc) {
+static Value createLinalgBodyCalculationForElementwiseOp(Operation *op,
+                                                         ValueRange bodyArgs,
+                                                         OpBuilder &builder,
+                                                         Location loc) {
   if (isa<tcp::AddOp>(op))
-    return builder.create<AddFOp>(loc, lhsBodyArg, rhsBodyArg);
+    return builder.create<AddFOp>(loc, bodyArgs[0], bodyArgs[1]);
 
   if (isa<tcp::MaxOp>(op)) {
     auto greater =
-        builder.create<CmpFOp>(loc, CmpFPredicate::OGT, lhsBodyArg, rhsBodyArg);
-    return builder.create<SelectOp>(loc, greater, lhsBodyArg, rhsBodyArg);
+        builder.create<CmpFOp>(loc, CmpFPredicate::OGT, bodyArgs[0], bodyArgs[1]);
+    return builder.create<SelectOp>(loc, greater, bodyArgs[0], bodyArgs[1]);
   }
 
+  if (isa<tcp::ExpOp>(op))
+    return builder.create<ExpOp>(loc, bodyArgs[0]);
+
+  if (isa<tcp::TanhOp>(op))
+    return builder.create<TanhOp>(loc, bodyArgs[0]);
+
   op->dump();
-  llvm::report_fatal_error(
-      "unhandled op (see dump above) when lowering binary elementwise ops");
+  llvm::report_fatal_error("unhandled op (see dump above): linalg body "
+                           "calculation for elementwise op");
 }
 
 static LogicalResult
-matchAndRewriteBinaryElementwiseOp(Operation *op, ArrayRef<Value> operands,
+matchAndRewriteElementwiseOp(Operation *op, ArrayRef<Value> operands,
                              ConversionPatternRewriter &rewriter) {
   Location loc = op->getLoc();
   Value result = op->getResult(0);
@@ -187,8 +192,8 @@ matchAndRewriteElementwiseOp(Operation *op, ArrayRef<Value> operands,
         /*iterator_types=*/iterators,
         /*bodyBuilder=*/
         [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
-          auto scalarResult = createLinalgBodyCalculationForBinaryElementwise(
-              op, regionArgs[0], regionArgs[1], builder, loc);
+          auto scalarResult = createLinalgBodyCalculationForElementwiseOp(
+              op, regionArgs, builder, loc);
           builder.create<linalg::YieldOp>(loc, ValueRange({scalarResult}));
         });
   rewriter.replaceOp(op, results);
@@ -197,13 +202,13 @@
 namespace {
 template <typename SourceOp>
-class LowerBinaryElementwiseOp : public OpConversionPattern<SourceOp> {
+class LowerElementwiseOp : public OpConversionPattern<SourceOp> {
 public:
   using OpConversionPattern<SourceOp>::OpConversionPattern;
   LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    return matchAndRewriteBinaryElementwiseOp(op, operands, rewriter);
+    return matchAndRewriteElementwiseOp(op, operands, rewriter);
   }
 };
 } // namespace
@@ -322,9 +327,10 @@ class LowerShapedResultsToMemref
 
     patterns.insert<...>(typeConverter, context);
     target.addIllegalOp<...>();
-    patterns.insert<LowerBinaryElementwiseOp<tcp::AddOp>,
-                    LowerBinaryElementwiseOp<tcp::MaxOp>>(typeConverter,
-                                                          context);
+    patterns
+        .insert<LowerElementwiseOp<tcp::AddOp>, LowerElementwiseOp<tcp::MaxOp>,
+                LowerElementwiseOp<tcp::ExpOp>,
+                LowerElementwiseOp<tcp::TanhOp>>(typeConverter, context);
     target.addIllegalOp<...>();
     patterns.insert<...>(typeConverter, context);
     target.addIllegalOp<...>();
diff --git a/test/Conversion/TCFToTCP/basic.mlir b/test/Conversion/TCFToTCP/basic.mlir
index 44f5d2470..5de379873 100644
--- a/test/Conversion/TCFToTCP/basic.mlir
+++ b/test/Conversion/TCFToTCP/basic.mlir
@@ -1,5 +1,15 @@
 // RUN: npcomp-opt <%s -convert-tcf-to-tcp | FileCheck %s --dump-input=fail
 
+// CHECK-LABEL: func @unary_ops(
+// CHECK-SAME:    %[[ARG:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:           %[[RET:.*]] = tcp.exp %[[ARG]] : tensor<?xf32>
+// CHECK:           return %[[RET]] : tensor<?xf32>
+// CHECK:         }
+func @unary_ops(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.exp %arg0 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
 // CHECK-LABEL: func @tcf_add(
 // CHECK-SAME:    %[[LHS:.*]]: tensor<?xf32>,
 // CHECK-SAME:    %[[RHS:.*]]: tensor<?xf32>) -> tensor<?xf32> {
diff --git a/test/Dialect/TCF/ops.mlir b/test/Dialect/TCF/ops.mlir
index 0834b0313..8f5f9577f 100644
--- a/test/Dialect/TCF/ops.mlir
+++ b/test/Dialect/TCF/ops.mlir
@@ -1,11 +1,13 @@
-// RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail
+// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
 
 // CHECK-LABEL: func @binary_elementwise
 func @binary_elementwise(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
   // CHECK: tcf.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   // CHECK: tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: tcf.exp %arg0 : tensor<?xf32>
   %0 = tcf.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %1 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %2 = tcf.exp %arg0 : tensor<?xf32>
   return
 }
 
diff --git a/test/Dialect/TCP/ops.mlir b/test/Dialect/TCP/ops.mlir
index 85312250f..c3d574081 100644
--- a/test/Dialect/TCP/ops.mlir
+++ b/test/Dialect/TCP/ops.mlir
@@ -14,8 +14,10 @@ func @global() {
 func @binary_elementwise(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
   // CHECK: tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   // CHECK: tcp.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: tcp.exp %arg0 : tensor<?xf32>
   %0 = tcp.add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %1 = tcp.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %2 = tcp.exp %arg0 : tensor<?xf32>
   return
 }
 
diff --git a/test/npcomp-run-mlir/binary-elementwise.mlir b/test/npcomp-run-mlir/binary-elementwise.mlir
deleted file mode 100644
index f176c82e3..000000000
--- a/test/npcomp-run-mlir/binary-elementwise.mlir
+++ /dev/null
@@ -1,16 +0,0 @@
-// RUN: npcomp-run-mlir %s \
-// RUN:   -invoke max \
-// RUN:   -arg-value="dense<[1.0]> : tensor<1xf32>" \
-// RUN:   -arg-value="dense<[3.0]> : tensor<1xf32>" \
-// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
-// RUN:   | FileCheck %s --check-prefix=MAX
-
-// These ops share a lot of code paths. So we don't test the exact
-// broadcasting behavior and error checking for all of them.
-
-// MAX: output #0: dense<3.000000e+00> : tensor<1xf32>
-func @max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
-  %0 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-  return %0 : tensor<?xf32>
-}
-
diff --git a/test/npcomp-run-mlir/elementwise.mlir b/test/npcomp-run-mlir/elementwise.mlir
new file mode 100644
index 000000000..ec14c347d
--- /dev/null
+++ b/test/npcomp-run-mlir/elementwise.mlir
@@ -0,0 +1,39 @@
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke max \
+// RUN:   -arg-value="dense<[1.0]> : tensor<1xf32>" \
+// RUN:   -arg-value="dense<[3.0]> : tensor<1xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=MAX
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke exp \
+// RUN:   -arg-value="dense<[0.0, 1.0]> : tensor<2xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=EXP
+
+// RUN: npcomp-run-mlir %s \
+// RUN:   -invoke tanh \
+// RUN:   -arg-value="dense<[0.0, 1.0]> : tensor<2xf32>" \
+// RUN:   -shared-libs=%npcomp_runtime_shlib 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=TANH
+
+// These ops share a lot of code paths. So we don't test the exact
+// broadcasting behavior and error checking for all of them.
+
+// MAX: output #0: dense<3.000000e+00> : tensor<1xf32>
+func @max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.max %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// EXP: output #0: dense<[1.000000e+00, 2.71828175]> : tensor<2xf32>
+func @exp(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.exp %arg0 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// TANH: output #0: dense<[0.000000e+00, 0.761594116]> : tensor<2xf32>
+func @tanh(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tcf.tanh %arg0 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
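
Note on the LowerToLLVM change above: populateExpandTanhPattern is the upstream
MLIR helper declared in the newly included
mlir/Dialect/StandardOps/Transforms/Passes.h. It rewrites std.tanh into
std.exp-based arithmetic, so the conversion to LLVM never has to handle tanh
directly. As a rough sketch of the identity it applies (this function and its
value names are illustrative only, not part of the patch): for x >= 0 it uses
tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x)), for x < 0 the mirrored form
(exp(2x) - 1) / (exp(2x) + 1), and it selects between the two branches, so exp
is only ever taken of a non-positive value and cannot overflow.

// Illustrative only: the x >= 0 branch of the expansion, written in the
// std-dialect MLIR of this vintage (exp/mulf/subf/addf/divf were std ops).
func @tanh_nonneg_branch(%x: f32) -> f32 {
  %one = constant 1.000000e+00 : f32
  %neg_two = constant -2.000000e+00 : f32
  %scaled = mulf %x, %neg_two : f32   // -2 * x
  %e = exp %scaled : f32              // exp(-2x), in (0, 1] when x >= 0
  %num = subf %one, %e : f32          // 1 - exp(-2x)
  %den = addf %one, %e : f32          // 1 + exp(-2x)
  %r = divf %num, %den : f32          // tanh(x)
  return %r : f32
}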