From 5009cbf55ca2210ee5d7822e73e674f039b139a2 Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Thu, 21 Oct 2021 05:15:10 +0000
Subject: [PATCH] Add lowering of aten.matmul op.

Lowering of `aten.matmul` op is added from torch to linalg dialect.
The different cases correspond to
https://pytorch.org/docs/stable/generated/torch.matmul.html.
TODO: Broadcasting in case of batch-matmul is yet to be taken care of.

Signed-off-by: Prashant Kumar
---
 e2e_testing/torchscript/main.py               |   1 +
 e2e_testing/torchscript/matmul.py             | 132 ++++++++++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td      |  15 ++
 .../TorchToLinalg/TorchToLinalg.cpp           | 149 ++++++++++++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  50 ++++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  27 ++++
 .../jit_ir/build_tools/torch_ods_gen.py       |   1 +
 test/Dialect/Torch/decompose-complex-ops.mlir |  27 ++++
 test/Dialect/Torch/refine-types.mlir          |  30 ++++
 9 files changed, 432 insertions(+)
 create mode 100644 e2e_testing/torchscript/matmul.py
 create mode 100644 test/Dialect/Torch/decompose-complex-ops.mlir

diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py
index de5cdcd3e..8566fd2b8 100644
--- a/e2e_testing/torchscript/main.py
+++ b/e2e_testing/torchscript/main.py
@@ -35,6 +35,7 @@ from . import quantized_models
 from . import elementwise
 from . import reduction
 from . import argmax
+from . import matmul
 
 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
diff --git a/e2e_testing/torchscript/matmul.py b/e2e_testing/torchscript/matmul.py
new file mode 100644
index 000000000..09d8f50a0
--- /dev/null
+++ b/e2e_testing/torchscript/matmul.py
@@ -0,0 +1,132 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+class MatmulDot(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: MatmulDot())
+def Matmul_dot(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3))
+
+# ==============================================================================
+
+class Matmul2D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: Matmul2D())
+def Matmul_2d(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4), tu.rand(4, 5))
+
+# ==============================================================================
+
+class MatmulVecMat(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: MatmulVecMat())
+def Matmul_vecmat(module, tu: TestUtils):
+    module.forward(tu.rand(4), tu.rand(4, 5))
+
+# ==============================================================================
+
+class MatmulMatVec(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: MatmulMatVec())
+def Matmul_matvec(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5), tu.rand(5))
+
+# ==============================================================================
+
+class Matmul3D(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: Matmul3D())
+def Matmul_3d(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
+
+# ==============================================================================
+
+class Matmul4d(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: Matmul4d())
+def Matmul_4d(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
+
+# ==============================================================================
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index cdf2f3021..2fba3ed43 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -871,6 +871,21 @@ def Torch_AtenMmOp : Torch_Op<"aten.mm", [
   let assemblyFormat = "$self `,` $mat2 attr-dict `:` type($self) `,` type($mat2) `->` type($result)";
 }
 
+def Torch_AtenMatmulOp : Torch_Op<"aten.matmul", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::matmul : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
+}
+
 def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
     AllowsTypeRefinement,
     HasValueSemantics
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index addde2e8f..b26c25cd8 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -942,6 +942,153 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenMatmulOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    Value lhs = operands[0];
+    Value rhs = operands[1];
+
+    unsigned lhsRank = lhs.getType().cast<RankedTensorType>().getRank();
+    unsigned rhsRank = rhs.getType().cast<RankedTensorType>().getRank();
+
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    Type elementType = newResultType.cast<TensorType>().getElementType();
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    // The different cases of the torch.matmul op are described here:
+    // https://pytorch.org/docs/stable/generated/torch.matmul.html
+
+    // First Case: Dot Product.
+    if (lhsRank == 1 && rhsRank == 1) {
+      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
+      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
+
+      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
+
+      Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType);
+      Value dotProd =
+          rewriter
+              .create<linalg::DotOp>(loc, zeroTensor.getType(),
+                                     ValueRange{lhs, rhs}, zeroTensor)
+              .getResult(0);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, dotProd);
+      return success();
+    }
+
+    // Second Case: Vec-Mat Multiplication.
+    if (lhsRank == 1 && rhsRank == 2) {
+      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
+      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
+      Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
+      checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
+
+      Value zeroTensor =
+          createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType);
+      Value matmul =
+          rewriter
+              .create<linalg::VecmatOp>(loc, zeroTensor.getType(),
+                                        ValueRange{lhs, rhs}, zeroTensor)
+              .getResult(0);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
+      return success();
+    }
+
+    // Third Case: Matrix-Vec Multiplication.
+    if (lhsRank == 2 && rhsRank == 1) {
+      Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
+      Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
+      Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
+      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
+
+      Value zeroTensor =
+          createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType);
+      Value matmul =
+          rewriter
+              .create<linalg::MatvecOp>(loc, zeroTensor.getType(),
+                                        ValueRange{lhs, rhs}, zeroTensor)
+              .getResult(0);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
+      return success();
+    }
+
+    // Fourth Case: Batch-Matrix Multiplication.
+    // TODO: Broadcasting of the batch dimensions is not yet supported.
+    if (lhsRank >= 3 && rhsRank >= 3 && lhsRank == rhsRank) {
+
+      unsigned batchRank = lhsRank - 2;
+      SmallVector<Value> resultShape;
+
+      SmallVector<AffineExpr> lhsExpr;
+      SmallVector<AffineExpr> rhsExpr;
+      SmallVector<AffineExpr> outExpr;
+      SmallVector<StringRef> iteratorTypes;
+
+      // Since broadcasting is a TODO, check whether the lhs and rhs batch
+      // dimensions match.
+      for (unsigned i = 0; i < batchRank; i++) {
+        Value lhsBatch = getDimOp(rewriter, loc, lhs, i);
+        Value rhsBatch = getDimOp(rewriter, loc, rhs, i);
+        resultShape.push_back(lhsBatch);
+        lhsExpr.push_back(rewriter.getAffineDimExpr(i));
+        rhsExpr.push_back(rewriter.getAffineDimExpr(i));
+        outExpr.push_back(rewriter.getAffineDimExpr(i));
+        iteratorTypes.push_back(getParallelIteratorTypeName());
+        checkDimEqualHelper(rewriter, loc, lhsBatch, rhsBatch);
+      }
+
+      Value lhsDim0 = getDimOp(rewriter, loc, lhs, batchRank);
+      Value lhsDim1 = getDimOp(rewriter, loc, lhs, batchRank + 1);
+      Value rhsDim0 = getDimOp(rewriter, loc, rhs, batchRank);
+      Value rhsDim1 = getDimOp(rewriter, loc, rhs, batchRank + 1);
+      checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
+
+      // Push the final matrix dimensions.
+      resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
+
+      lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
+                                     rewriter.getAffineDimExpr(batchRank + 1)});
+      rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
+                                     rewriter.getAffineDimExpr(batchRank + 2)});
+      outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
+                                     rewriter.getAffineDimExpr(batchRank + 2)});
+
+      Value initTensor0 =
+          createZeroInitTensor(rewriter, loc, resultShape, elementType);
+
+      auto indexingMaps =
+          AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
+      iteratorTypes.insert(iteratorTypes.end(),
+                           {"parallel", "reduction", "parallel"});
+
+      Value finalRes =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, newResultType, ValueRange{lhs, rhs}, initTensor0,
+                  /*indexingMaps=*/indexingMaps,
+                  /*iteratorTypes=*/iteratorTypes,
+                  [&](OpBuilder &b, Location loc, ValueRange args) {
+                    Value l = args[0], r = args[1], res = args[2];
+                    Value mul = b.create<arith::MulFOp>(loc, l, r);
+                    Value add = b.create<arith::AddFOp>(loc, mul, res);
+                    b.create<linalg::YieldOp>(loc, add);
+                  })
+              .getResult(0);
+
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
 public:
@@ -2352,6 +2499,8 @@ public:
     RewritePatternSet patterns(context);
     target.addIllegalOp<AtenMmOp>();
     patterns.add<ConvertAtenMmOp>(typeConverter, context);
+    target.addIllegalOp<AtenMatmulOp>();
+    patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
     target.addIllegalOp<AtenBmmOp>();
     patterns.add<ConvertAtenBmmOp>(typeConverter, context);
     target.addIllegalOp<AtenLinearOp>();
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 46e7432d9..374e3a4dd 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -20,6 +20,19 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
+// Helper function to get the rank of a `Base tensor type`.
+// -1 is returned if the tensorRank can't be determined.
+static int getTensorRank(Value tensor) {
+  int tensorRank = -1;
+  BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
+
+  if (tensorType.hasSizes()) {
+    ArrayRef<int64_t> tensorShape = tensorType.getSizes();
+    tensorRank = tensorShape.size();
+  }
+  return tensorRank;
+}
+
 // Decompose softmax into: exp(x) / sum(exp(x))
 namespace {
 class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
@@ -75,6 +88,33 @@ public:
 };
 } // namespace
 
+// Decompose torch.matmul into: torch.mm and torch.bmm according to ranks.
+namespace {
+class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenMatmulOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value lhs = op.self();
+    Value rhs = op.other();
+
+    int lhsRank = getTensorRank(lhs);
+    int rhsRank = getTensorRank(rhs);
+
+    // If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
+    if (lhsRank == 2 && rhsRank == 2)
+      rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
+
+    // If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
+    if (lhsRank == 3 && rhsRank == 3)
+      rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -86,7 +126,17 @@ class DecomposeComplexOpsPass
     patterns.add<DecomposeAtenSoftmaxIntOp>(context);
     target.addIllegalOp<AtenSoftmaxIntOp>();
+    patterns.add<DecomposeAtenMatmulOp>(context);
+    target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
+      Value lhs = op.self();
+      Value rhs = op.other();
+      int lhsRank = getTensorRank(lhs);
+      int rhsRank = getTensorRank(rhs);
+
+      // Make aten.matmul legal if the following condition is satisfied.
+      return (lhsRank != 2 || rhsRank != 2) && (lhsRank != 3 || rhsRank != 3);
+    });
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
       return signalPassFailure();
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 4dfb1c06f..87a3e1247 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -366,6 +366,8 @@ public:
       return visitAtenEmbeddingOp(embedding, operands);
     } else if (auto bmm = dyn_cast<AtenBmmOp>(op)) {
       return visitAtenBmmOp(bmm, operands);
+    } else if (auto matmul = dyn_cast<AtenMatmulOp>(op)) {
+      return visitAtenMatmulOp(matmul, operands);
     } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
       return visitAtenSoftmaxIntOp(softmaxIntOp, operands);
     }
@@ -467,6 +469,9 @@ private:
   ChangeResult
   visitAtenSoftmaxIntOp(AtenSoftmaxIntOp op,
                         ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+  ChangeResult
+  visitAtenMatmulOp(AtenMatmulOp op,
+                    ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 };
 } // namespace
 
@@ -1118,6 +1123,28 @@ ChangeResult TypeAnalyzer::visitAtenBmmOp(
   return getLatticeElement(op->getResult(0)).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAtenMatmulOp(
+    AtenMatmulOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  auto self = operands[0]->getValue();
+  auto other = operands[1]->getValue();
+  if (!self.hasSizes || !other.hasSizes)
+    return getLatticeElement(op->getResult(0)).join(knowledge);
+  unsigned maxRank = self.sizes.size() > other.sizes.size()
+                         ? self.sizes.size()
+                         : other.sizes.size();
+  unsigned lhsDim = self.sizes.size() > 2 ? 2 : self.sizes.size();
+  unsigned rhsDim = other.sizes.size() > 2 ? 2 : other.sizes.size();
+  unsigned batchDim = maxRank > 2 ? maxRank - 2 : 0;
+  unsigned matDim = (lhsDim - 1) + (rhsDim - 1);
+  unsigned resultRank = batchDim + matDim;
+  knowledge.sizes.resize(resultRank, kUnknownSize);
+  knowledge.dtype = joinElementTypes(self.dtype, other.dtype);
+  knowledge.hasSizes = true;
+  return getLatticeElement(op->getResult(0)).join(knowledge);
+}
+
 // -----------------------------------------------------------------------------
 // Transforms.
 // -----------------------------------------------------------------------------
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index a38b75cd4..87877f83c 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -472,6 +472,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     # Non-elementwise tensor compute ops
     emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
     emit("aten::mm : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::matmul : (Tensor, Tensor) -> (Tensor)")
     emit(
         "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)"
     )
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
new file mode 100644
index 000000000..5127ce816
--- /dev/null
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL:   func @matmul_no_decompose
+// CHECK:           torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
+func @matmul_no_decompose(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}
+
+
+// -----
+
+// CHECK-LABEL:   func @matmul_decompose_2d
+// CHECK:           torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
+func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}
+
+// -----
+
+// CHECK-LABEL:   func @matmul_decompose_3d(
+// CHECK:           torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
+func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
index 2173d196f..1b715386d 100644
--- a/test/Dialect/Torch/refine-types.mlir
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -955,3 +955,33 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
   %ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
   return %ret : !torch.tensor
 }
+
+
+// ----
+// CHECK-LABEL:   func @aten_matmul_broadcast_matrix(
+// CHECK-SAME:                                       %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
+// CHECK-SAME:                                       %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>)
+// CHECK-SAME:                                       -> !torch.tensor {
+// CHECK:           %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<[?,?,?,?,?],f32>
+// CHECK:           %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?,?],f32> to !torch.tensor
+// CHECK:           return %[[CAST]] : !torch.tensor
+// CHECK:         }
+func @aten_matmul_broadcast_matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}
+
+
+// ----
+// CHECK-LABEL:   func @aten_matmul_broadcast_vector(
+// CHECK-SAME:                                       %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
+// CHECK-SAME:                                       %[[RHS:.*]]: !torch.vtensor<[?],f32>)
+// CHECK-SAME:                                       -> !torch.tensor {
+// CHECK:           %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor<[?,?,?,?],f32>
+// CHECK:           %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<[?,?,?,?],f32> to !torch.tensor
+// CHECK:           return %[[CAST]] : !torch.tensor
+// CHECK:         }
+func @aten_matmul_broadcast_vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}
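For quick reference (not part of the patch), the sketch below exercises in plain PyTorch the same torch.matmul rank cases that the new e2e tests and the linalg lowering cover, printing the result shapes. The comments restate the rank rule that visitAtenMatmulOp encodes (result rank = batch dims + remaining matrix dims); the last line shows batch-dimension broadcasting, which PyTorch allows but the lowering above does not yet handle, matching the TODO in the commit message.

import torch

print(torch.matmul(torch.rand(3), torch.rand(3)).shape)            # dot product -> torch.Size([])
print(torch.matmul(torch.rand(4), torch.rand(4, 5)).shape)         # vec-mat     -> torch.Size([5])
print(torch.matmul(torch.rand(4, 5), torch.rand(5)).shape)         # mat-vec     -> torch.Size([4])
print(torch.matmul(torch.rand(3, 4), torch.rand(4, 5)).shape)      # 2-D mm      -> torch.Size([3, 5])
print(torch.matmul(torch.rand(3, 4, 5), torch.rand(3, 5, 4)).shape)        # batched, rank 3 -> torch.Size([3, 4, 4])
print(torch.matmul(torch.rand(4, 5, 6, 7), torch.rand(4, 5, 7, 6)).shape)  # batched, rank 4 -> torch.Size([4, 5, 6, 6])
# Batch broadcasting is valid in PyTorch but is still a TODO in the lowering above:
print(torch.matmul(torch.rand(2, 3, 4, 5), torch.rand(5, 6)).shape)        # -> torch.Size([2, 3, 4, 6])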