mirror of https://github.com/llvm/torch-mlir
Add lowering of aten.matmul op.
Lowering of the `aten.matmul` op from the Torch dialect to the Linalg dialect is added. The cases handled correspond to the semantics described at https://pytorch.org/docs/stable/generated/torch.matmul.html. TODO: Broadcasting of the batch dimensions in the batch-matmul case is not yet handled.

Signed-off-by: Prashant Kumar <prashant@nod-labs.com>
parent: e276dbbaa6
commit: 5009cbf55c
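For reference, a minimal PyTorch sketch of the torch.matmul cases this lowering targets (illustration only, not part of the patch; it assumes nothing beyond stock PyTorch):

import torch

# Result shapes of torch.matmul for the cases handled by this change.
dot    = torch.matmul(torch.rand(3), torch.rand(3))              # 1-D x 1-D -> 0-D (dot product)
vecmat = torch.matmul(torch.rand(4), torch.rand(4, 5))           # 1-D x 2-D -> (5,)
matvec = torch.matmul(torch.rand(4, 5), torch.rand(5))           # 2-D x 1-D -> (4,)
mm     = torch.matmul(torch.rand(3, 4), torch.rand(4, 5))        # 2-D x 2-D -> (3, 5)
bmm    = torch.matmul(torch.rand(2, 3, 4), torch.rand(2, 4, 5))  # batched   -> (2, 3, 5)
print(dot.shape, vecmat.shape, matvec.shape, mm.shape, bmm.shape)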
@@ -35,6 +35,7 @@ from . import quantized_models
 from . import elementwise
 from . import reduction
 from . import argmax
+from . import matmul

 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
@@ -0,0 +1,132 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+class MatmulDot(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: MatmulDot())
+def Matmul_dot(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3))
+
+# ==============================================================================
+
+class Matmul2D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: Matmul2D())
+def Matmul_2d(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4), tu.rand(4, 5))
+
+# ==============================================================================
+
+class MatmulVecMat(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: MatmulVecMat())
+def Matmul_vecmat(module, tu: TestUtils):
+    module.forward(tu.rand(4), tu.rand(4, 5))
+
+# ==============================================================================
+
+class MatmulMatVec(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: MatmulMatVec())
+def Matmul_matvec(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5), tu.rand(5))
+
+# ==============================================================================
+
+class Matmul3D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: Matmul3D())
+def Matmul_3d(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
+
+# ==============================================================================
+
+class Matmul4d(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: Matmul4d())
+def Matmul_4d(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
@@ -871,6 +871,21 @@ def Torch_AtenMmOp : Torch_Op<"aten.mm", [
   let assemblyFormat = "$self `,` $mat2 attr-dict `:` type($self) `,` type($mat2) `->` type($result)";
 }

+def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::matmul : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
+}
+
 def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
     AllowsTypeRefinement,
     HasValueSemantics
@@ -942,6 +942,153 @@ public:
 };
 } // namespace

+namespace {
+class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenMatmulOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value lhs = operands[0];
+    Value rhs = operands[1];
+
+    unsigned lhsRank = lhs.getType().cast<RankedTensorType>().getRank();
+    unsigned rhsRank = rhs.getType().cast<RankedTensorType>().getRank();
+
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    Type elementType = newResultType.cast<TensorType>().getElementType();
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    // The different cases of the torch.matmul op are described here:
+    // https://pytorch.org/docs/stable/generated/torch.matmul.html
+
+    // First case: dot product.
+    if (lhsRank == 1 && rhsRank == 1) {
+      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
+      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
+
+      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
+
+      Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType);
+      Value dotProd =
+          rewriter
+              .create<linalg::DotOp>(loc, zeroTensor.getType(),
+                                     ValueRange{lhs, rhs}, zeroTensor)
+              .getResult(0);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, dotProd);
+      return success();
+    }
+
+    // Second case: vector-matrix multiplication.
+    if (lhsRank == 1 && rhsRank == 2) {
+      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
+      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
+      Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
+      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
+
+      Value zeroTensor =
+          createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType);
+      Value matmul =
+          rewriter
+              .create<linalg::VecmatOp>(loc, zeroTensor.getType(),
+                                        ValueRange{lhs, rhs}, zeroTensor)
+              .getResult(0);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
+      return success();
+    }
+
+    // Third case: matrix-vector multiplication.
+    if (lhsRank == 2 && rhsRank == 1) {
+      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
+      Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
+      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
+      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
+
+      Value zeroTensor =
+          createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
+      Value matmul =
+          rewriter
+              .create<linalg::MatvecOp>(loc, zeroTensor.getType(),
+                                        ValueRange{lhs, rhs}, zeroTensor)
+              .getResult(0);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
+      return success();
+    }
+
+    // Fourth case: batch-matrix multiplication.
+    // TODO: Broadcasting of the batch dimensions is not yet supported.
+    if (lhsRank >= 3 && rhsRank >= 3 && lhsRank == rhsRank) {
+
+      unsigned batchRank = lhsRank - 2;
+      SmallVector<Value, 4> resultShape;
+
+      SmallVector<AffineExpr> lhsExpr;
+      SmallVector<AffineExpr> rhsExpr;
+      SmallVector<AffineExpr> outExpr;
+      SmallVector<StringRef> iteratorTypes;
+
+      // Since broadcasting is a TODO, check that the lhs and rhs batch
+      // dimensions match.
+      for (unsigned i = 0; i < batchRank; i++) {
+        Value lhsBatch = getDimOp(rewriter, loc, lhs, i);
+        Value rhsBatch = getDimOp(rewriter, loc, rhs, i);
+        resultShape.push_back(lhsBatch);
+        lhsExpr.push_back(rewriter.getAffineDimExpr(i));
+        rhsExpr.push_back(rewriter.getAffineDimExpr(i));
+        outExpr.push_back(rewriter.getAffineDimExpr(i));
+        iteratorTypes.push_back(getParallelIteratorTypeName());
+        checkDimEqualHelper(rewriter, loc, lhsBatch, rhsBatch);
+      }
+
+      Value lhsDim0 = getDimOp(rewriter, loc, lhs, batchRank);
+      Value lhsDim1 = getDimOp(rewriter, loc, lhs, batchRank + 1);
+      Value rhsDim0 = getDimOp(rewriter, loc, rhs, batchRank);
+      Value rhsDim1 = getDimOp(rewriter, loc, rhs, batchRank + 1);
+      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
+
+      // Push the final matrix dimensions.
+      resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
+
+      lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
+                                     rewriter.getAffineDimExpr(batchRank + 1)});
+      rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
+                                     rewriter.getAffineDimExpr(batchRank + 2)});
+      outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
+                                     rewriter.getAffineDimExpr(batchRank + 2)});
+
+      Value initTensor0 =
+          createZeroInitTensor(rewriter, loc, resultShape, elementType);
+
+      auto indexingMaps =
+          AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
+      iteratorTypes.insert(iteratorTypes.end(),
+                           {"parallel", "reduction", "parallel"});
+
+      Value finalRes =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, newResultType, ValueRange{lhs, rhs}, initTensor0,
+                  /*indexingMaps=*/indexingMaps,
+                  /*iteratorTypes=*/iteratorTypes,
+                  [&](OpBuilder &b, Location loc, ValueRange args) {
+                    Value l = args[0], r = args[1], res = args[2];
+                    Value mul = b.create<arith::MulFOp>(loc, l, r);
+                    Value add = b.create<arith::AddFOp>(loc, mul, res);
+                    b.create<linalg::YieldOp>(loc, add);
+                  })
+              .getResult(0);
+
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
 public:
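The linalg.generic built in the batch-matmul branch computes out[b..., i, j] = sum over k of lhs[b..., i, k] * rhs[b..., k, j], with the batch dimensions and i/j parallel and k as the reduction. A rough PyTorch sketch of that iteration structure for same-rank inputs with matching batch dimensions (illustration only; the helper name is made up and this is not code from the patch):

import torch

def batch_matmul_reference(lhs, rhs):
    # out[b, i, j] accumulates lhs[b, i, k] * rhs[b, k, j] over k,
    # after flattening all leading batch dimensions into one.
    assert lhs.shape[:-2] == rhs.shape[:-2] and lhs.shape[-1] == rhs.shape[-2]
    l = lhs.reshape(-1, lhs.shape[-2], lhs.shape[-1])
    r = rhs.reshape(-1, rhs.shape[-2], rhs.shape[-1])
    out = torch.zeros(l.shape[0], l.shape[1], r.shape[2])
    for b in range(l.shape[0]):              # parallel (batch)
        for i in range(l.shape[1]):          # parallel
            for j in range(r.shape[2]):      # parallel
                for k in range(l.shape[2]):  # reduction
                    out[b, i, j] += l[b, i, k] * r[b, k, j]
    return out.reshape(*lhs.shape[:-2], l.shape[1], r.shape[2])

lhs, rhs = torch.rand(2, 3, 4, 5), torch.rand(2, 3, 5, 6)
assert torch.allclose(batch_matmul_reference(lhs, rhs), torch.matmul(lhs, rhs), atol=1e-5)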
@@ -2352,6 +2499,8 @@ public:
     RewritePatternSet patterns(context);
     target.addIllegalOp<AtenMmOp>();
     patterns.add<ConvertAtenMmOp>(typeConverter, context);
+    target.addIllegalOp<AtenMatmulOp>();
+    patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
     target.addIllegalOp<AtenBmmOp>();
     patterns.add<ConvertAtenBmmOp>(typeConverter, context);
     target.addIllegalOp<AtenLinearOp>();
@@ -20,6 +20,19 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;

+// Helper function to get the rank of a `BaseTensorType`.
+// -1 is returned if the tensor rank can't be determined.
+static int getTensorRank(Value tensor) {
+  int tensorRank = -1;
+  BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
+
+  if (tensorType.hasSizes()) {
+    ArrayRef<int64_t> tensorShape = tensorType.getSizes();
+    tensorRank = tensorShape.size();
+  }
+  return tensorRank;
+}
+
 // Decompose softmax into: exp(x) / sum(exp(x))
 namespace {
 class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
@@ -75,6 +88,33 @@ public:
 };
 } // namespace

+// Decompose torch.matmul into: torch.mm and torch.bmm according to ranks.
+namespace {
+class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenMatmulOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.self();
+    Value rhs = op.other();
+
+    int lhsRank = getTensorRank(lhs);
+    int rhsRank = getTensorRank(rhs);
+
+    // If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
+    if (lhsRank == 2 && rhsRank == 2)
+      rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
+
+    // If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
+    if (lhsRank == 3 && rhsRank == 3)
+      rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
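The decomposition above relies on torch.matmul agreeing with torch.mm for two rank-2 operands and with torch.bmm for two rank-3 operands. A quick sanity check in plain PyTorch (illustration only, not part of the patch):

import torch

a2, b2 = torch.rand(3, 4), torch.rand(4, 5)
a3, b3 = torch.rand(2, 3, 4), torch.rand(2, 4, 5)
assert torch.allclose(torch.matmul(a2, b2), torch.mm(a2, b2))   # rank 2 -> mm
assert torch.allclose(torch.matmul(a3, b3), torch.bmm(a3, b3))  # rank 3 -> bmm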
@@ -86,7 +126,17 @@ class DecomposeComplexOpsPass

     patterns.add<DecomposeAtenSoftmaxIntOp>(context);
     target.addIllegalOp<AtenSoftmaxIntOp>();
+    patterns.add<DecomposeAtenMatmulOp>(context);
+    target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
+      Value lhs = op.self();
+      Value rhs = op.other();
+
+      int lhsRank = getTensorRank(lhs);
+      int rhsRank = getTensorRank(rhs);
+
+      // aten.matmul stays legal (i.e., is not decomposed) unless both ranks
+      // are 2 or both ranks are 3.
+      return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
+    });
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
       return signalPassFailure();
@@ -366,6 +366,8 @@ public:
       return visitAtenEmbeddingOp(embedding, operands);
     } else if (auto bmm = dyn_cast<AtenBmmOp>(op)) {
       return visitAtenBmmOp(bmm, operands);
+    } else if (auto matmul = dyn_cast<AtenMatmulOp>(op)) {
+      return visitAtenMatmulOp(matmul, operands);
     } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
       return visitAtenSoftmaxIntOp(softmaxIntOp, operands);
     }
@@ -467,6 +469,9 @@ private:
   ChangeResult
   visitAtenSoftmaxIntOp(AtenSoftmaxIntOp op,
                         ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenMatmulOp(AtenMatmulOp op,
+                    ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 };
 } // namespace

@@ -1118,6 +1123,28 @@ ChangeResult TypeAnalyzer::visitAtenBmmOp(
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }

+ChangeResult TypeAnalyzer::visitAtenMatmulOp(
+    AtenMatmulOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  auto self = operands[0]->getValue();
+  auto other = operands[1]->getValue();
+  if (!self.hasSizes || !other.hasSizes)
+    return getLatticeElement(op->getResult(0)).join(knowledge);
+  unsigned maxRank = self.sizes.size() > other.sizes.size()
+                         ? self.sizes.size()
+                         : other.sizes.size();
+  unsigned lhsDim = self.sizes.size() > 2 ? 2 : self.sizes.size();
+  unsigned rhsDim = other.sizes.size() > 2 ? 2 : other.sizes.size();
+  unsigned batchDim = maxRank > 2 ? maxRank - 2 : 0;
+  unsigned matDim = (lhsDim - 1) + (rhsDim - 1);
+  unsigned resultRank = batchDim + matDim;
+  knowledge.sizes.resize(resultRank, kUnknownSize);
+  knowledge.dtype = joinElementTypes(self.dtype, other.dtype);
+  knowledge.hasSizes = true;
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
 // -----------------------------------------------------------------------------
 // Transforms.
 // -----------------------------------------------------------------------------
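The rank rule encoded by visitAtenMatmulOp can be summarized as: each operand contributes at most a 2-D matrix part (a 1-D operand drops its dimension from the result), and any remaining leading dimensions become batch dimensions. A small Python sketch of that rule (the function name is hypothetical, used only for illustration):

def matmul_result_rank(lhs_rank: int, rhs_rank: int) -> int:
    # Mirrors the computation in visitAtenMatmulOp.
    max_rank = max(lhs_rank, rhs_rank)
    lhs_dim = min(lhs_rank, 2)
    rhs_dim = min(rhs_rank, 2)
    batch_dim = max_rank - 2 if max_rank > 2 else 0
    mat_dim = (lhs_dim - 1) + (rhs_dim - 1)
    return batch_dim + mat_dim

assert matmul_result_rank(1, 1) == 0  # dot product -> 0-D result
assert matmul_result_rank(2, 2) == 2  # mm -> 2-D result
assert matmul_result_rank(5, 3) == 5  # broadcast batch-matmul -> 5-D (see refine-types test below)
assert matmul_result_rank(5, 1) == 4  # batched matrix x vector -> 4-D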
@@ -472,6 +472,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     # Non-elementwise tensor compute ops
     emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
     emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
     emit(
         "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
     )
@@ -0,0 +1,27 @@
+// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @matmul_no_decompose
+// CHECK:       torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
+func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}
+
+
+// -----
+
+// CHECK-LABEL: func @matmul_decompose_2d
+// CHECK:       torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
+func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}
+
+// -----
+
+// CHECK-LABEL: func @matmul_decompose_3d(
+// CHECK:       torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
+func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}
@@ -955,3 +955,33 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
   %ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
   return %ret : !torch.tensor
 }
+
+// ----
+// CHECK-LABEL: func @aten_matmul_broadcast_matrix(
+// CHECK-SAME:    %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
+// CHECK-SAME:    %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>)
+// CHECK-SAME:    -> !torch.tensor {
+// CHECK:         %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
+// CHECK:         %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
+// CHECK:         return %[[CAST]] : !torch.tensor
+// CHECK:       }
+func @aten_matmul_broadcast_matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}
+
+// ----
+// CHECK-LABEL: func @aten_matmul_broadcast_vector(
+// CHECK-SAME:    %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
+// CHECK-SAME:    %[[RHS:.*]]: !torch.vtensor<[?],f32>)
+// CHECK-SAME:    -> !torch.tensor {
+// CHECK:         %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
+// CHECK:         %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
+// CHECK:         return %[[CAST]] : !torch.tensor
+// CHECK:       }
+func @aten_matmul_broadcast_vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}