Add lowering of aten.matmul op.

Lowering of the `aten.matmul` op from the torch dialect to the linalg dialect is added.
The different cases correspond to
https://pytorch.org/docs/stable/generated/torch.matmul.html.
TODO: Broadcasting in the batch-matmul case is not yet handled.
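For reference, the shape behavior of the cases handled here can be checked in plain PyTorch (an illustrative snippet, not part of the change):

import torch

torch.matmul(torch.rand(3), torch.rand(3)).shape              # dot product  -> torch.Size([])
torch.matmul(torch.rand(4), torch.rand(4, 5)).shape           # vec-mat      -> torch.Size([5])
torch.matmul(torch.rand(4, 5), torch.rand(5)).shape           # mat-vec      -> torch.Size([4])
torch.matmul(torch.rand(3, 4), torch.rand(4, 5)).shape        # mat-mat      -> torch.Size([3, 5])
torch.matmul(torch.rand(2, 3, 4), torch.rand(2, 4, 5)).shape  # batch matmul -> torch.Size([2, 3, 5])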

Signed-off-by: Prashant Kumar <prashant@nod-labs.com>
Prashant Kumar 2021-10-21 05:15:10 +00:00 committed by Yi Zhang
parent e276dbbaa6
commit 5009cbf55c
9 changed files with 432 additions and 0 deletions


@@ -35,6 +35,7 @@ from . import quantized_models
from . import elementwise
from . import reduction
from . import argmax
from . import matmul

def _get_argparse():
    config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']


@@ -0,0 +1,132 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================


class MatmulDot(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: MatmulDot())
def Matmul_dot(module, tu: TestUtils):
    module.forward(tu.rand(3), tu.rand(3))

# ==============================================================================


class Matmul2D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: Matmul2D())
def Matmul_2d(module, tu: TestUtils):
    module.forward(tu.rand(3, 4), tu.rand(4, 5))

# ==============================================================================


class MatmulVecMat(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: MatmulVecMat())
def Matmul_vecmat(module, tu: TestUtils):
    module.forward(tu.rand(4), tu.rand(4, 5))

# ==============================================================================


class MatmulMatVec(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: MatmulMatVec())
def Matmul_matvec(module, tu: TestUtils):
    module.forward(tu.rand(4, 5), tu.rand(5))

# ==============================================================================


class Matmul3D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: Matmul3D())
def Matmul_3d(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))

# ==============================================================================


class Matmul4d(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: Matmul4d())
def Matmul_4d(module, tu: TestUtils):
    module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))

# ==============================================================================


@@ -871,6 +871,21 @@ def Torch_AtenMmOp : Torch_Op<"aten.mm", [
  let assemblyFormat = "$self `,` $mat2 attr-dict `:` type($self) `,` type($mat2) `->` type($result)";
}
def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::matmul : (Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$other
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
}
def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
    AllowsTypeRefinement,
    HasValueSemantics


@@ -942,6 +942,153 @@ public:
};
} // namespace
namespace {
class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenMatmulOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value lhs = operands[0];
Value rhs = operands[1];
unsigned lhsRank = lhs.getType().cast<RankedTensorType>().getRank();
unsigned rhsRank = rhs.getType().cast<RankedTensorType>().getRank();
Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = newResultType.cast<TensorType>().getElementType();
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
    // The different cases of the torch.matmul op are documented here:
    // https://pytorch.org/docs/stable/generated/torch.matmul.html
    // First Case: Dot Product.
if (lhsRank == 1 && rhsRank == 1) {
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType);
Value dotProd =
rewriter
.create<linalg::DotOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, dotProd);
return success();
}
// Second Case: Vec-Mat Multiplication.
if (lhsRank == 1 && rhsRank == 2) {
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
Value zeroTensor =
createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType);
Value matmul =
rewriter
.create<linalg::VecmatOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
// Third Case: Matrix-Vec Multiplication.
if (lhsRank == 2 && rhsRank == 1) {
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
Value zeroTensor =
createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
Value matmul =
rewriter
.create<linalg::MatvecOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
    // Fourth Case: Batch-Matrix Multiplication.
    // TODO: Broadcasting of batch dimensions is not yet supported.
if (lhsRank >= 3 && rhsRank >= 3 && lhsRank == rhsRank) {
unsigned batchRank = lhsRank - 2;
SmallVector<Value, 4> resultShape;
SmallVector<AffineExpr> lhsExpr;
SmallVector<AffineExpr> rhsExpr;
SmallVector<AffineExpr> outExpr;
SmallVector<StringRef> iteratorTypes;
      // Since broadcasting is a TODO, check whether the lhs and rhs batch
      // dimensions match.
for (unsigned i = 0; i < batchRank; i++) {
Value lhsBatch = getDimOp(rewriter, loc, lhs, i);
Value rhsBatch = getDimOp(rewriter, loc, rhs, i);
resultShape.push_back(lhsBatch);
lhsExpr.push_back(rewriter.getAffineDimExpr(i));
rhsExpr.push_back(rewriter.getAffineDimExpr(i));
outExpr.push_back(rewriter.getAffineDimExpr(i));
iteratorTypes.push_back(getParallelIteratorTypeName());
checkDimEqualHelper(rewriter, loc, lhsBatch, rhsBatch);
}
Value lhsDim0 = getDimOp(rewriter, loc, lhs, batchRank);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, batchRank + 1);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, batchRank);
Value rhsDim1 = getDimOp(rewriter, loc, rhs, batchRank + 1);
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
// Push the final matrix dimension.
resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
rewriter.getAffineDimExpr(batchRank + 1)});
rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
rewriter.getAffineDimExpr(batchRank + 2)});
outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
rewriter.getAffineDimExpr(batchRank + 2)});
Value initTensor0 =
createZeroInitTensor(rewriter, loc, resultShape, elementType);
auto indexingMaps =
AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
iteratorTypes.insert(iteratorTypes.end(),
{"parallel", "reduction", "parallel"});
Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, newResultType, ValueRange{lhs, rhs}, initTensor0,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value l = args[0], r = args[1], res = args[2];
Value mul = b.create<arith::MulFOp>(loc, l, r);
Value add = b.create<arith::AddFOp>(loc, mul, res);
b.create<linalg::YieldOp>(loc, add);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
return success();
}
return failure();
}
};
} // namespace
namespace {
class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
public:
@@ -2352,6 +2499,8 @@ public:
    RewritePatternSet patterns(context);
    target.addIllegalOp<AtenMmOp>();
    patterns.add<ConvertAtenMmOp>(typeConverter, context);
target.addIllegalOp<AtenMatmulOp>();
patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
    target.addIllegalOp<AtenBmmOp>();
    patterns.add<ConvertAtenBmmOp>(typeConverter, context);
    target.addIllegalOp<AtenLinearOp>();
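The batch case above is lowered to a linalg.generic whose indexing maps contract the trailing dimension: out[..., m, n] = sum_k lhs[..., m, k] * rhs[..., k, n]. A minimal Python reference for the computation the generic op performs, assuming equal ranks and exactly matching batch dimensions (illustrative only, not part of the change):

import torch

def batch_matmul_reference(lhs, rhs):
    # Equal ranks >= 3; batch dimensions must match exactly (no broadcasting yet).
    assert lhs.dim() == rhs.dim() and lhs.dim() >= 3
    assert lhs.shape[:-2] == rhs.shape[:-2] and lhs.shape[-1] == rhs.shape[-2]
    # out[..., m, n] = sum_k lhs[..., m, k] * rhs[..., k, n]
    return torch.einsum('...mk,...kn->...mn', lhs, rhs)

print(batch_matmul_reference(torch.rand(4, 5, 6, 7), torch.rand(4, 5, 7, 6)).shape)  # torch.Size([4, 5, 6, 6])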


@@ -20,6 +20,19 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// Helper function to get the rank of a `BaseTensorType`.
// -1 is returned if the rank can't be determined.
static int getTensorRank(Value tensor) {
int tensorRank = -1;
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
if (tensorType.hasSizes()) {
ArrayRef<int64_t> tensorShape = tensorType.getSizes();
tensorRank = tensorShape.size();
}
return tensorRank;
}
// Decompose softmax into: exp(x) / sum(exp(x))
namespace {
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
@@ -75,6 +88,33 @@ public:
};
} // namespace
// Decompose torch.matmul into torch.mm or torch.bmm depending on the operand ranks.
namespace {
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenMatmulOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value lhs = op.self();
Value rhs = op.other();
int lhsRank = getTensorRank(lhs);
int rhsRank = getTensorRank(rhs);
    // If both lhs and rhs ranks are 2, map the op to `aten.mm`.
if (lhsRank == 2 && rhsRank == 2)
rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
    // If both lhs and rhs ranks are 3, map the op to `aten.bmm`.
if (lhsRank == 3 && rhsRank == 3)
rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
    : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -86,7 +126,17 @@ class DecomposeComplexOpsPass
    patterns.add<DecomposeAtenSoftmaxIntOp>(context);
    target.addIllegalOp<AtenSoftmaxIntOp>();
patterns.add<DecomposeAtenMatmulOp>(context);
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
Value lhs = op.self();
Value rhs = op.other();
int lhsRank = getTensorRank(lhs);
int rhsRank = getTensorRank(rhs);
      // aten.matmul remains legal (is not decomposed) unless both operands
      // are rank 2 (-> aten.mm) or both are rank 3 (-> aten.bmm).
return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
});
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      return signalPassFailure();


@@ -366,6 +366,8 @@ public:
      return visitAtenEmbeddingOp(embedding, operands);
    } else if (auto bmm = dyn_cast<AtenBmmOp>(op)) {
      return visitAtenBmmOp(bmm, operands);
} else if (auto matmul = dyn_cast<AtenMatmulOp>(op)) {
return visitAtenMatmulOp(matmul, operands);
    } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
      return visitAtenSoftmaxIntOp(softmaxIntOp, operands);
    }
@@ -467,6 +469,9 @@ private:
  ChangeResult
  visitAtenSoftmaxIntOp(AtenSoftmaxIntOp op,
                        ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenMatmulOp(AtenMatmulOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
};
} // namespace
@@ -1118,6 +1123,28 @@ ChangeResult TypeAnalyzer::visitAtenBmmOp(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenMatmulOp(
AtenMatmulOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
auto self = operands[0]->getValue();
auto other = operands[1]->getValue();
if (!self.hasSizes || !other.hasSizes)
return getLatticeElement(op->getResult(0)).join(knowledge);
unsigned maxRank = self.sizes.size() > other.sizes.size()
? self.sizes.size()
: other.sizes.size();
unsigned lhsDim = self.sizes.size() > 2 ? 2 : self.sizes.size();
unsigned rhsDim = other.sizes.size() > 2 ? 2 : other.sizes.size();
unsigned batchDim = maxRank > 2 ? maxRank - 2 : 0;
unsigned matDim = (lhsDim - 1) + (rhsDim - 1);
unsigned resultRank = batchDim + matDim;
knowledge.sizes.resize(resultRank, kUnknownSize);
knowledge.dtype = joinElementTypes(self.dtype, other.dtype);
knowledge.hasSizes = true;
return getLatticeElement(op->getResult(0)).join(knowledge);
}
// -----------------------------------------------------------------------------
// Transforms.
// -----------------------------------------------------------------------------
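The result-rank arithmetic in visitAtenMatmulOp follows torch.matmul's documented semantics: each 1-D operand contributes no output dimension for its contracted axis, and any leading dimensions beyond the last two act as batch dimensions. A Python cross-check against eager PyTorch (illustrative only):

import torch

def matmul_result_rank(lhs_rank, rhs_rank):
    max_rank = max(lhs_rank, rhs_rank)
    lhs_dim = min(lhs_rank, 2)
    rhs_dim = min(rhs_rank, 2)
    batch_dim = max_rank - 2 if max_rank > 2 else 0
    return batch_dim + (lhs_dim - 1) + (rhs_dim - 1)

assert matmul_result_rank(1, 1) == torch.matmul(torch.rand(2), torch.rand(2)).dim()                    # 0
assert matmul_result_rank(5, 3) == torch.matmul(torch.rand(2, 2, 2, 2, 2), torch.rand(2, 2, 2)).dim()  # 5
assert matmul_result_rank(5, 1) == torch.matmul(torch.rand(2, 2, 2, 2, 2), torch.rand(2)).dim()        # 4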


@@ -472,6 +472,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    # Non-elementwise tensor compute ops
    emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
    emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
    emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
    emit(
        "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
    )


@@ -0,0 +1,27 @@
// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @matmul_no_decompose
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func @matmul_decompose_2d
// CHECK: torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func @matmul_decompose_3d(
// CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}


@@ -955,3 +955,33 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
  %ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
  return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func @aten_matmul_broadcast_matrix(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>)
// CHECK-SAME: -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
// CHECK: }
func @aten_matmul_broadcast_matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func @aten_matmul_broadcast_vector(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>)
// CHECK-SAME: -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
// CHECK: }
func @aten_matmul_broadcast_vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
return %0 : !torch.tensor
}