mirror of https://github.com/llvm/torch-mlir
[RefBackend] Support element-wise multiply op
Register the following for the multiply op: - tcf.mul - tcp.mul - TCF->TCP lowering - Shape transfer, broadcasted multiplicands - Lower to standard `MulFOp` op
pull/96/head
parent
510f226df2
commit
94ea6f7c92
|
@ -43,6 +43,13 @@ def TCF_MaxOp : BinaryArithmeticOp<"max"> {
|
|||
}];
|
||||
}
|
||||
|
||||
def TCF_MulOp : BinaryArithmeticOp<"mul"> {
|
||||
let summary = "Multiplies two tensors element-wise.";
|
||||
let description = [{
|
||||
Multiplies the two operand tensors element by element and returns a new resulting tensor. The tensor types must match and shapes must be broadcastable.
|
||||
}];
|
||||
}
|
||||
|
||||
class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
TCF_Op<mnemonic,
|
||||
!listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,
|
||||
|
|
|
@ -42,6 +42,13 @@ def TCP_MaxOp : BinaryArithmeticOp<"max"> {
|
|||
}];
|
||||
}
|
||||
|
||||
def TCP_MulOp : BinaryArithmeticOp<"mul"> {
|
||||
let summary = "Multiplies two tensors element-wise.";
|
||||
let description = [{
|
||||
Multiplies the two operand tensors element by element and returns a new resulting tensor. The tensor types must match and shapes must be broadcastable.
|
||||
}];
|
||||
}
|
||||
|
||||
class UnaryArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
TCP_Op<mnemonic,
|
||||
!listconcat(traits, [AllTypesMatch<["operand", "result"]>])>,
|
||||
|
|
|
@ -73,6 +73,9 @@ matchAndRewriteBinaryElementwise(Operation *op, PatternRewriter &rewriter) {
|
|||
} else if (isa<tcf::MaxOp>(op)) {
|
||||
binaryOpResult = rewriter.create<tcp::MaxOp>(
|
||||
loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
|
||||
} else if (isa<tcf::MulOp>(op)) {
|
||||
binaryOpResult = rewriter.create<tcp::MulOp>(
|
||||
loc, result.getType(), lhsBroadcasted, rhsBroadcasted);
|
||||
} else {
|
||||
op->dump();
|
||||
llvm::report_fatal_error(
|
||||
|
@ -167,7 +170,8 @@ public:
|
|||
patterns.insert<ConvertUnaryElementwise<tcf::ExpOp>,
|
||||
ConvertUnaryElementwise<tcf::TanhOp>>(context);
|
||||
patterns.insert<ConvertBinaryElementwise<tcf::AddOp>,
|
||||
ConvertBinaryElementwise<tcf::MaxOp>>(context);
|
||||
ConvertBinaryElementwise<tcf::MaxOp>,
|
||||
ConvertBinaryElementwise<tcf::MulOp>>(context);
|
||||
patterns.insert<ConvertMatmul>(context);
|
||||
(void)applyPatternsAndFoldGreedily(module, patterns);
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ static SmallVector<Value, 6> bypassResultShapes(Operation &op) {
|
|||
}
|
||||
|
||||
// Elementwise ops.
|
||||
if (isa<tcp::AddOp, tcp::MaxOp, tcp::ExpOp, tcp::TanhOp>(op)) {
|
||||
if (isa<tcp::AddOp, tcp::MaxOp, tcp::MulOp, tcp::ExpOp, tcp::TanhOp>(op)) {
|
||||
return {builder.create<shape::ShapeOfOp>(op.getLoc(), op.getOperand(0))};
|
||||
}
|
||||
|
||||
|
@ -161,6 +161,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(Operation *op,
|
|||
return builder.create<SelectOp>(loc, greater, bodyArgs[0], bodyArgs[1]);
|
||||
}
|
||||
|
||||
if (isa<tcp::MulOp>(op)) {
|
||||
return builder.create<MulFOp>(loc, bodyArgs[0], bodyArgs[1]);
|
||||
}
|
||||
|
||||
if (isa<tcp::ExpOp>(op))
|
||||
return builder.create<ExpOp>(loc, bodyArgs[0]);
|
||||
|
||||
|
@ -275,10 +279,11 @@ class TCPBufferizePass : public TCPBufferizeBase<TCPBufferizePass> {
|
|||
target.addIllegalOp<tcp::BroadcastToOp>();
|
||||
patterns.insert<BufferizeElementwiseOp<tcp::AddOp>,
|
||||
BufferizeElementwiseOp<tcp::MaxOp>,
|
||||
BufferizeElementwiseOp<tcp::MulOp>,
|
||||
BufferizeElementwiseOp<tcp::ExpOp>,
|
||||
BufferizeElementwiseOp<tcp::TanhOp>>(typeConverter,
|
||||
context);
|
||||
target.addIllegalOp<tcp::AddOp, tcp::MaxOp>();
|
||||
target.addIllegalOp<tcp::AddOp, tcp::MaxOp, tcp::MulOp>();
|
||||
patterns.insert<BufferizeMatmulOp>(typeConverter, context);
|
||||
target.addIllegalOp<tcp::MatmulOp>();
|
||||
|
||||
|
|
|
@ -5,6 +5,13 @@
|
|||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=MAX
|
||||
|
||||
// RUN: npcomp-run-mlir %s \
|
||||
// RUN: -invoke mul \
|
||||
// RUN: -arg-value="dense<[1.0, 2.0]> : tensor<2xf32>" \
|
||||
// RUN: -arg-value="dense<[3.0, 4.0]> : tensor<2xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=MUL
|
||||
|
||||
// RUN: npcomp-run-mlir %s \
|
||||
// RUN: -invoke exp \
|
||||
// RUN: -arg-value="dense<[0.0, 1.0]> : tensor<2xf32>" \
|
||||
|
@ -26,6 +33,12 @@ func @max(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
|||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// MUL: output #0: dense<[3.000000e+00, 8.000000e+00]> : tensor<2xf32>
|
||||
func @mul(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = tcf.mul %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// EXP: output #0: dense<[1.000000e+00, 2.71828175]> : tensor<2xf32>
|
||||
func @exp(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = tcf.exp %arg0 : tensor<?xf32>
|
||||
|
|
Loading…
Reference in New Issue