[RefBackend] Support element-wise multiply op

Register the following for the multiply op:
- tcf.mul
- tcp.mul
- TCP->TCP lowering
- Shape transfer, broadcasted multiplicands
- Lower to standard `MulFOp` op
pull/96/head
Aaron J Arthurs 2020-10-27 10:04:15 -05:00 committed by Sean Silva
parent 510f226df2
commit 94ea6f7c92
5 changed files with 39 additions and 3 deletions

View File

@ -43,6 +43,13 @@ def TCF_MaxOp : BinaryArithmeticOp<"max"> {
}];
}
def TCF_MulOp : BinaryArithmeticOp<"mul"> {
let summary = "Multiply an input tensor by a scalar tensor.";
let description = [{
Multiplies each element of the input `input` with the scalar `other` and returns a new resulting tensor. The tensor types must match and shapes must be broadcastable.
}];
}
class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
TCF_Op<mnemonic,
!listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,

View File

@ -42,6 +42,13 @@ def TCP_MaxOp : BinaryArithmeticOp<"max"> {
}];
}
def TCF_MulOp : BinaryArithmeticOp<"mul"> {
let summary = "Multiply an input tensor by a scalar tensor.";
let description = [{
Multiplies each element of the input `input` with the scalar `other` and returns a new resulting tensor. The tensor types must match and shapes must be broadcastable.
}];
}
class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
TCP_Op<mnemonic,
!listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,

View File

@ -73,6 +73,9 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) {
} else if (isa<tcf::MaxOp>(op)) {
binaryOpResult = rewriter.create<tcp::MaxOp>(
loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
} else if (isa<tcf::MulOp>(op)) {
binaryOpResult = rewriter.create<tcp::MulOp>(
loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
} else {
op->dump();
llvm::report_fatal_error(
@ -167,7 +170,8 @@ public:
patterns.insert<ConvertUnaryElementwise<tcf::ExpOp>,
ConvertUnaryElementwise<tcf::TanhOp>>(context);
patterns.insert<ConvertBinaryElementwise<tcf::AddOp>,
ConvertBinaryElementwise<tcf::MaxOp>>(context);
ConvertBinaryElementwise<tcf::MaxOp>,
ConvertBinaryElementwise<tcf::MulOp>>(context);
patterns.insert<ConvertMatmul>(context);
(void)applyPatternsAndFoldGreedily(module, patterns);
}

View File

@ -33,7 +33,7 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
}
// Elementwise ops.
if (isa<tcp::AddOp, tcp::MaxOp, tcp::ExpOp, tcp::TanhOp>(op)) {
if (isa<tcp::AddOp, tcp::MaxOp, tcp::MulOp, tcp::ExpOp, tcp::TanhOp>(op)) {
return {builder.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand(0))};
}
@ -161,6 +161,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(Operation *op,
return builder.create<SelectOp>(loc, greater, bodyArgs[0], bodyArgs[1]);
}
if (isa<tcp::MulOp>(op)) {
return builder.create<MulFOp>(loc, bodyArgs[0], bodyArgs[1]);
}
if (isa<tcp::ExpOp>(op))
return builder.create<ExpOp>(loc, bodyArgs[0]);
@ -275,10 +279,11 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
target.addIllegalOp<tcp::BroadcastToOp>();
patterns.insert<BufferizeElementwiseOp<tcp::AddOp>,
BufferizeElementwiseOp<tcp::MaxOp>,
BufferizeElementwiseOp<tcp::MulOp>,
BufferizeElementwiseOp<tcp::ExpOp>,
BufferizeElementwiseOp<tcp::TanhOp>>(typeConverter,
context);
target.addIllegalOp<tcp::AddOp, tcp::MaxOp>();
target.addIllegalOp<tcp::AddOp, tcp::MaxOp, tcp::MulOp>();
patterns.insert<BufferizeMatmulOp>(typeConverter, context);
target.addIllegalOp<tcp::MatmulOp>();

View File

@ -5,6 +5,13 @@
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=MAX
// RUN: npcomp-run-mlir %s \
// RUN: -invoke mul \
// RUN: -arg-value="dense<[1.0, 2.0]> : tensor<2xf32>" \
// RUN: -arg-value="dense<[3.0, 4.0]> : tensor<2xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=MUL
// RUN: npcomp-run-mlir %s \
// RUN: -invoke exp \
// RUN: -arg-value="dense<[0.0, 1.0]> : tensor<2xf32>" \
@ -26,6 +33,12 @@ func @max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
return %0 : tensor<?xf32>
}
// MUL: output #0: dense<[3.000000e+00, 8.000000e+00]> : tensor<2xf32>
func @mul(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = tcf.mul %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// EXP: output #0: dense<[1.000000e+00, 2.71828175]> : tensor<2xf32>
func @exp(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = tcf.exp %arg0 : tensor<?xf32>