mirror of https://github.com/llvm/torch-mlir
Add lowering of aten.matmul op.
Lowering of the `aten.matmul` op from the Torch dialect to the Linalg dialect is added. The supported cases correspond to https://pytorch.org/docs/stable/generated/torch.matmul.html. TODO: broadcasting of the batch dimensions in batch-matmul is not yet handled.

Signed-off-by: Prashant Kumar <prashant@nod-labs.com>

pull/385/head snapshot-20211026.46
parent e276dbbaa6
commit 5009cbf55c
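For context, torch.matmul dispatches on operand rank: dot product for two 1-D tensors, vector-matrix and matrix-vector products for mixed 1-D/2-D operands, a plain matrix product for two 2-D tensors, and a batched matrix product when an operand has rank 3 or higher. A shape-only sketch of that behavior in plain PyTorch (illustrative, not part of the patch):

import torch

# Dot product: 1-D x 1-D -> 0-D scalar.
assert torch.matmul(torch.rand(3), torch.rand(3)).shape == torch.Size([])
# Vector-matrix: 1-D x 2-D -> 1-D.
assert torch.matmul(torch.rand(4), torch.rand(4, 5)).shape == (5,)
# Matrix-vector: 2-D x 1-D -> 1-D.
assert torch.matmul(torch.rand(4, 5), torch.rand(5)).shape == (4,)
# Matrix-matrix: 2-D x 2-D -> 2-D.
assert torch.matmul(torch.rand(3, 4), torch.rand(4, 5)).shape == (3, 5)
# Batched: leading dimensions are batch dimensions.
assert torch.matmul(torch.rand(2, 3, 4), torch.rand(2, 4, 5)).shape == (2, 3, 5)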
@@ -35,6 +35,7 @@ from . import quantized_models
from . import elementwise
from . import reduction
from . import argmax
from . import matmul

def _get_argparse():
    config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
@@ -0,0 +1,132 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================

class MatmulDot(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: MatmulDot())
def Matmul_dot(module, tu: TestUtils):
    module.forward(tu.rand(3), tu.rand(3))

# ==============================================================================

class Matmul2D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: Matmul2D())
def Matmul_2d(module, tu: TestUtils):
    module.forward(tu.rand(3, 4), tu.rand(4, 5))

# ==============================================================================

class MatmulVecMat(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: MatmulVecMat())
def Matmul_vecmat(module, tu: TestUtils):
    module.forward(tu.rand(4), tu.rand(4, 5))

# ==============================================================================

class MatmulMatVec(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: MatmulMatVec())
def Matmul_matvec(module, tu: TestUtils):
    module.forward(tu.rand(4, 5), tu.rand(5))

# ==============================================================================

class Matmul3D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: Matmul3D())
def Matmul_3d(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))

# ==============================================================================

class Matmul4d(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: Matmul4d())
def Matmul_4d(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))

# ==============================================================================
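The two batched tests above keep the batch dimensions of both operands identical, since broadcasting of batch dimensions is still a TODO in the lowering. A quick shape check of what they exercise, in plain PyTorch (illustrative, not part of the patch):

import torch

# Matmul_3d: (3, 4, 5) x (3, 5, 4) -> (3, 4, 4)
assert torch.matmul(torch.rand(3, 4, 5), torch.rand(3, 5, 4)).shape == (3, 4, 4)
# Matmul_4d: (4, 5, 6, 7) x (4, 5, 7, 6) -> (4, 5, 6, 6)
assert torch.matmul(torch.rand(4, 5, 6, 7), torch.rand(4, 5, 7, 6)).shape == (4, 5, 6, 6)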
@@ -871,6 +871,21 @@ def Torch_AtenMmOp : Torch_Op<"aten.mm", [
  let assemblyFormat = "$self `,` $mat2 attr-dict `:` type($self) `,` type($mat2) `->` type($result)";
}

def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::matmul : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}

def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
    AllowsTypeRefinement,
    HasValueSemantics
@@ -942,6 +942,153 @@ public:
};
} // namespace

namespace {
class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenMatmulOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value lhs = operands[0];
    Value rhs = operands[1];

    unsigned lhsRank = lhs.getType().cast<RankedTensorType>().getRank();
    unsigned rhsRank = rhs.getType().cast<RankedTensorType>().getRank();

    Type newResultType = getTypeConverter()->convertType(op.getType());
    Type elementType = newResultType.cast<TensorType>().getElementType();

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    // The different cases of the torch.matmul op are described here:
    // https://pytorch.org/docs/stable/generated/torch.matmul.html

    // First Case: Dot Product.
    if (lhsRank == 1 && rhsRank == 1) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);

      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);

      Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType);
      Value dotProd =
          rewriter
              .create<linalg::DotOp>(loc, zeroTensor.getType(),
                                     ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, dotProd);
      return success();
    }

    // Second Case: Vec-Mat Multiplication.
    if (lhsRank == 1 && rhsRank == 2) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
      Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);

      Value zeroTensor =
          createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType);
      Value matmul =
          rewriter
              .create<linalg::VecmatOp>(loc, zeroTensor.getType(),
                                        ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
      return success();
    }

    // Third Case: Matrix-Vec Multiplication.
    if (lhsRank == 2 && rhsRank == 1) {
      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
      Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

      Value zeroTensor =
          createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
      Value matmul =
          rewriter
              .create<linalg::MatvecOp>(loc, zeroTensor.getType(),
                                        ValueRange{lhs, rhs}, zeroTensor)
              .getResult(0);
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
      return success();
    }

    // Fourth Case: Batch-Matrix Multiplication.
    // TODO: Broadcasting of the batch dimensions is not yet handled.
    if (lhsRank >= 3 && rhsRank >= 3 && lhsRank == rhsRank) {

      unsigned batchRank = lhsRank - 2;
      SmallVector<Value, 4> resultShape;

      SmallVector<AffineExpr> lhsExpr;
      SmallVector<AffineExpr> rhsExpr;
      SmallVector<AffineExpr> outExpr;
      SmallVector<StringRef> iteratorTypes;

      // Since broadcasting is a TODO, check that the lhs and rhs batch
      // dimensions match.
      for (unsigned i = 0; i < batchRank; i++) {
        Value lhsBatch = getDimOp(rewriter, loc, lhs, i);
        Value rhsBatch = getDimOp(rewriter, loc, rhs, i);
        resultShape.push_back(lhsBatch);
        lhsExpr.push_back(rewriter.getAffineDimExpr(i));
        rhsExpr.push_back(rewriter.getAffineDimExpr(i));
        outExpr.push_back(rewriter.getAffineDimExpr(i));
        iteratorTypes.push_back(getParallelIteratorTypeName());
        checkDimEqualHelper(rewriter, loc, lhsBatch, rhsBatch);
      }

      Value lhsDim0 = getDimOp(rewriter, loc, lhs, batchRank);
      Value lhsDim1 = getDimOp(rewriter, loc, lhs, batchRank + 1);
      Value rhsDim0 = getDimOp(rewriter, loc, rhs, batchRank);
      Value rhsDim1 = getDimOp(rewriter, loc, rhs, batchRank + 1);
      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

      // Push the final matrix dimensions.
      resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});

      lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
                                     rewriter.getAffineDimExpr(batchRank + 1)});
      rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
                                     rewriter.getAffineDimExpr(batchRank + 2)});
      outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
                                     rewriter.getAffineDimExpr(batchRank + 2)});

      Value initTensor0 =
          createZeroInitTensor(rewriter, loc, resultShape, elementType);

      auto indexingMaps =
          AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
      iteratorTypes.insert(iteratorTypes.end(),
                           {"parallel", "reduction", "parallel"});

      Value finalRes =
          rewriter
              .create<linalg::GenericOp>(
                  loc, newResultType, ValueRange{lhs, rhs}, initTensor0,
                  /*indexingMaps=*/indexingMaps,
                  /*iteratorTypes=*/iteratorTypes,
                  [&](OpBuilder &b, Location loc, ValueRange args) {
                    Value l = args[0], r = args[1], res = args[2];
                    Value mul = b.create<arith::MulFOp>(loc, l, r);
                    Value add = b.create<arith::AddFOp>(loc, mul, res);
                    b.create<linalg::YieldOp>(loc, add);
                  })
              .getResult(0);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
      return success();
    }
    return failure();
  }
};
} // namespace

namespace {
class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
public:
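In the fourth case above, the linalg.generic op computes out[b..., m, n] = sum over k of lhs[b..., m, k] * rhs[b..., k, n], with the batch dimensions and m, n as parallel loops and k as the reduction loop. A reference formulation of the same contraction in PyTorch, assuming equal-rank operands as required by the lhsRank == rhsRank guard (illustrative, not part of the patch):

import torch

def batch_matmul_reference(lhs, rhs):
    # Same contraction the generic op expresses: batch dimensions stay
    # parallel, the inner dimension k is reduced.
    assert lhs.dim() == rhs.dim() and lhs.dim() >= 3
    return torch.einsum('...mk,...kn->...mn', lhs, rhs)

lhs = torch.rand(2, 3, 4, 5)
rhs = torch.rand(2, 3, 5, 6)
assert torch.allclose(batch_matmul_reference(lhs, rhs), torch.matmul(lhs, rhs))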
@@ -2352,6 +2499,8 @@ public:
    RewritePatternSet patterns(context);
    target.addIllegalOp<AtenMmOp>();
    patterns.add<ConvertAtenMmOp>(typeConverter, context);
    target.addIllegalOp<AtenMatmulOp>();
    patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
    target.addIllegalOp<AtenBmmOp>();
    patterns.add<ConvertAtenBmmOp>(typeConverter, context);
    target.addIllegalOp<AtenLinearOp>();
@@ -20,6 +20,19 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

// Helper function to get the rank of a `BaseTensorType`.
// -1 is returned if the rank can't be determined.
static int getTensorRank(Value tensor) {
  int tensorRank = -1;
  BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();

  if (tensorType.hasSizes()) {
    ArrayRef<int64_t> tensorShape = tensorType.getSizes();
    tensorRank = tensorShape.size();
  }
  return tensorRank;
}

// Decompose softmax into: exp(x) / sum(exp(x))
namespace {
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
@@ -75,6 +88,33 @@ public:
};
} // namespace

// Decompose torch.matmul into torch.mm or torch.bmm according to the operand ranks.
namespace {
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMatmulOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.self();
    Value rhs = op.other();

    int lhsRank = getTensorRank(lhs);
    int rhsRank = getTensorRank(rhs);

    // If both lhs and rhs ranks are 2, map the op to `aten.mm`.
    if (lhsRank == 2 && rhsRank == 2)
      rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);

    // If both lhs and rhs ranks are 3, map the op to `aten.bmm`.
    if (lhsRank == 3 && rhsRank == 3)
      rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);

    return success();
  }
};
} // namespace

namespace {
class DecomposeComplexOpsPass
    : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
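This decomposition relies on torch.matmul agreeing with torch.mm for two rank-2 operands and with torch.bmm for two rank-3 operands, which can be checked directly in PyTorch (illustrative, not part of the patch):

import torch

# Rank-2 operands: matmul behaves like mm.
a2, b2 = torch.rand(3, 4), torch.rand(4, 5)
assert torch.allclose(torch.matmul(a2, b2), torch.mm(a2, b2))

# Rank-3 operands: matmul behaves like bmm.
a3, b3 = torch.rand(2, 3, 4), torch.rand(2, 4, 5)
assert torch.allclose(torch.matmul(a3, b3), torch.bmm(a3, b3))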
@@ -86,7 +126,17 @@ class DecomposeComplexOpsPass

    patterns.add<DecomposeAtenSoftmaxIntOp>(context);
    target.addIllegalOp<AtenSoftmaxIntOp>();
    patterns.add<DecomposeAtenMatmulOp>(context);
    target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
      Value lhs = op.self();
      Value rhs = op.other();

      int lhsRank = getTensorRank(lhs);
      int rhsRank = getTensorRank(rhs);

      // aten.matmul stays legal for any rank combination that is not
      // decomposed above (i.e. anything other than 2x2 and 3x3).
      return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
    });
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      return signalPassFailure();
@@ -366,6 +366,8 @@ public:
      return visitAtenEmbeddingOp(embedding, operands);
    } else if (auto bmm = dyn_cast<AtenBmmOp>(op)) {
      return visitAtenBmmOp(bmm, operands);
    } else if (auto matmul = dyn_cast<AtenMatmulOp>(op)) {
      return visitAtenMatmulOp(matmul, operands);
    } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
      return visitAtenSoftmaxIntOp(softmaxIntOp, operands);
    }
@@ -467,6 +469,9 @@ private:
  ChangeResult
  visitAtenSoftmaxIntOp(AtenSoftmaxIntOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenMatmulOp(AtenMatmulOp op,
                    ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace

@@ -1118,6 +1123,28 @@ ChangeResult TypeAnalyzer::visitAtenBmmOp(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenMatmulOp(
    AtenMatmulOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  auto self = operands[0]->getValue();
  auto other = operands[1]->getValue();
  if (!self.hasSizes || !other.hasSizes)
    return getLatticeElement(op->getResult(0)).join(knowledge);
  unsigned maxRank = self.sizes.size() > other.sizes.size()
                         ? self.sizes.size()
                         : other.sizes.size();
  unsigned lhsDim = self.sizes.size() > 2 ? 2 : self.sizes.size();
  unsigned rhsDim = other.sizes.size() > 2 ? 2 : other.sizes.size();
  unsigned batchDim = maxRank > 2 ? maxRank - 2 : 0;
  unsigned matDim = (lhsDim - 1) + (rhsDim - 1);
  unsigned resultRank = batchDim + matDim;
  knowledge.sizes.resize(resultRank, kUnknownSize);
  knowledge.dtype = joinElementTypes(self.dtype, other.dtype);
  knowledge.hasSizes = true;
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

// -----------------------------------------------------------------------------
// Transforms.
// -----------------------------------------------------------------------------
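The rank arithmetic above gives resultRank = batchDim + (lhsDim - 1) + (rhsDim - 1), with lhsDim and rhsDim clamped to 2 and batchDim = maxRank - 2 when maxRank exceeds 2. A small check of that formula against PyTorch, using the two rank combinations exercised by the new refine-types.mlir tests below (illustrative, not part of the patch):

import torch

def matmul_result_rank(lhs_rank, rhs_rank):
    # Mirrors the rank arithmetic in TypeAnalyzer::visitAtenMatmulOp.
    max_rank = max(lhs_rank, rhs_rank)
    lhs_dim = min(lhs_rank, 2)
    rhs_dim = min(rhs_rank, 2)
    batch_dim = max_rank - 2 if max_rank > 2 else 0
    return batch_dim + (lhs_dim - 1) + (rhs_dim - 1)

# Matrix rhs (rank 5 x rank 3): result rank 5.
assert matmul_result_rank(5, 3) == 5
assert torch.matmul(torch.rand(2, 2, 3, 4, 5), torch.rand(3, 5, 6)).dim() == 5
# Vector rhs (rank 5 x rank 1): result rank 4.
assert matmul_result_rank(5, 1) == 4
assert torch.matmul(torch.rand(2, 2, 3, 4, 5), torch.rand(5)).dim() == 4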
@@ -472,6 +472,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    # Non-elementwise tensor compute ops
    emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
    emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
    emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
    emit(
        "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
    )
@@ -0,0 +1,27 @@
// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s

// CHECK-LABEL: func @matmul_no_decompose
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
  return %0 : !torch.tensor
}


// -----

// CHECK-LABEL: func @matmul_decompose_2d
// CHECK: torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.tensor {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
  return %0 : !torch.tensor
}

// -----

// CHECK-LABEL: func @matmul_decompose_3d(
// CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
  return %0 : !torch.tensor
}
@@ -955,3 +955,33 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
  %ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
  return %ret : !torch.tensor
}


// ----
// CHECK-LABEL: func @aten_matmul_broadcast_matrix(
// CHECK-SAME:      %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME:      %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>)
// CHECK-SAME:      -> !torch.tensor {
// CHECK:           %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
// CHECK:           %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
// CHECK:           return %[[CAST]] : !torch.tensor
// CHECK:         }
func @aten_matmul_broadcast_matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
  return %0 : !torch.tensor
}


// ----
// CHECK-LABEL: func @aten_matmul_broadcast_vector(
// CHECK-SAME:      %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME:      %[[RHS:.*]]: !torch.vtensor<[?],f32>)
// CHECK-SAME:      -> !torch.tensor {
// CHECK:           %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
// CHECK:           %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
// CHECK:           return %[[CAST]] : !torch.tensor
// CHECK:         }
func @aten_matmul_broadcast_vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
  return %0 : !torch.tensor
}