From f0a24f59f6354df5a254779caaaaaf6918a4f038 Mon Sep 17 00:00:00 2001
From: Tanyo Kwok
Date: Thu, 4 Aug 2022 10:10:54 +0800
Subject: [PATCH] [MHLO] Init MHLO linear op patterns (#1132)

See RFC https://github.com/llvm/torch-mlir/issues/999

Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
---
 lib/Conversion/TorchToMhlo/CMakeLists.txt     |   1 +
 lib/Conversion/TorchToMhlo/Linear.cpp         | 405 ++++++++++++++++++
 lib/Conversion/TorchToMhlo/PopulatePatterns.h |   4 +
 lib/Conversion/TorchToMhlo/TorchToMhlo.cpp    |   5 +-
 test/Conversion/TorchToMhlo/linear.mlir       | 268 ++++++++++++
 utils/bazel/torch-mlir-overlay/BUILD.bazel    |   1 +
 6 files changed, 683 insertions(+), 1 deletion(-)
 create mode 100644 lib/Conversion/TorchToMhlo/Linear.cpp
 create mode 100644 test/Conversion/TorchToMhlo/linear.mlir

diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt
index 3c036e7ef..47126ab8d 100644
--- a/lib/Conversion/TorchToMhlo/CMakeLists.txt
+++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
   MhloLegalizeUtils.cpp
   BasicOp.cpp
   GatherOp.cpp
+  Linear.cpp
   ViewLikeOps.cpp
   ReductionOp.cpp
 
diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp
new file mode 100644
index 000000000..d9475e8ef
--- /dev/null
+++ b/lib/Conversion/TorchToMhlo/Linear.cpp
@@ -0,0 +1,405 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
+
+#include "../PassDetail.h"
+#include "./MhloLegalizeUtils.h"
+#include "./PopulatePatterns.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "torch-mlir/Conversion/Utils/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op,
+                         Value tensor, ArrayRef<int64_t> shape,
+                         ArrayRef<Value> dimSizes,
+                         ArrayRef<int64_t> broadcastDims) {
+  auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
+  auto loc = op->getLoc();
+  Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
+
+  RankedTensorType outTy =
+      RankedTensorType::get(shape, tensorTy.getElementType());
+
+  RankedTensorType attrTy =
+      RankedTensorType::get({static_cast<int64_t>(broadcastDims.size())},
+                            rewriter.getIntegerType(64));
+  auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);
+
+  auto broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
+      loc, outTy, tensor, mhloShape, broadcastAttr);
+  return broadcast;
+}
+
+Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
+                        ArrayRef<int64_t> inpTransDims) {
+  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  auto rank = inputTy.getRank();
+  auto transDims = mhlo::toPositiveDims(inpTransDims, rank);
+  auto inpShape = inputTy.getShape();
+  std::vector<int64_t> newShape;
+  newShape.reserve(rank);
+
+  for (auto d : transDims) {
+    newShape.push_back(inpShape[d]);
+  }
+
+  auto attrTy = RankedTensorType::get({static_cast<int64_t>(transDims.size())},
+                                      rewriter.getIntegerType(64));
+  auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims);
+
+  auto outTy = RankedTensorType::get(newShape, inputTy.getElementType());
+  auto result = rewriter.create<mhlo::TransposeOp>(op->getLoc(), outTy, input,
+                                                   permuteAttr);
+  return result.getResult();
+}
+
+void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
+                     Value &inpRhs, int64_t leadingRank) {
+  Value lhs = inpLhs;
+  Value rhs = inpRhs;
+  auto lhsRankTy = inpLhs.getType().dyn_cast<RankedTensorType>();
+  auto rhsRankTy = inpRhs.getType().dyn_cast<RankedTensorType>();
+
+  auto lhsRank = lhsRankTy.getRank();
+  auto rhsRank = rhsRankTy.getRank();
+
+  // The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be
+  // broadcastable).
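+  // Illustrative example (not part of the original patch, shapes taken from
+  // the tests added below): with lhs : tensor<256x120xf32>,
+  // rhs : tensor<4x120x256xf32> and leadingRank = 1, the lhs is broadcast to
+  // tensor<4x256x120xf32> using broadcastDims = [1, 2], so both operands share
+  // batch dimension 0 before the mhlo.dot_general is emitted.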
+  auto minRank = std::min(lhsRank, rhsRank);
+  auto leadingDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, leadingRank));
+  auto broadcastDims = llvm::to_vector<4>(
+      llvm::seq<int64_t>(leadingRank, minRank + leadingRank));
+  auto lhsShape = lhsRankTy.getShape();
+  auto rhsShape = rhsRankTy.getShape();
+  if (lhsRank < rhsRank) {
+    std::vector<int64_t> newShape(rhsShape.begin(),
+                                  rhsShape.begin() + leadingRank);
+    newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
+    auto newDimSizes =
+        *mhlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims);
+    auto lhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, lhs);
+    newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
+                       lhsDimSizes.end());
+    lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
+                             broadcastDims);
+  } else {
+    std::vector<int64_t> newShape(lhsShape.begin(),
+                                  lhsShape.begin() + leadingRank);
+    newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
+    auto newDimSizes =
+        *mhlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims);
+    auto rhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, rhs);
+    newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
+                       rhsDimSizes.end());
+    rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
+                             broadcastDims);
+  }
+
+  inpLhs = lhs;
+  inpRhs = rhs;
+}
+
+// Perform the basic n-dim matmul operation encompassing the handling of
+// broadcasting and dynamic shape propagation.
+// All PyTorch ops that leverage matrix multiplication will derive this and
+// implement their specialized input processing (e.g. transpose) and output
+// processing, e.g. GEMM or fully connected bias handling.
+template <typename AtenOpT>
+class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  // Each variant must implement corresponding parameter parsing options.
+  // Maintain separate input read functions for each variant because it is not
+  // necessarily true with all variants that the first two operands are the lhs
+  // and rhs.
+  virtual LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                         ConversionPatternRewriter &rewriter,
+                                         Value &lhs, Value &rhs) const {
+    return rewriter.notifyMatchFailure(
+        op,
+        "unimplemented matrix multiplication variant input parsing function");
+  }
+  LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter, Value &lhs,
+                              Value &rhs, Value &output) const {
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    auto lhsRank = lhsTy.getRank();
+    auto rhsRank = rhsTy.getRank();
+    auto lhsElemTy = lhsTy.getElementType();
+    auto rhsElemTy = rhsTy.getElementType();
+
+    if (lhsElemTy != rhsElemTy)
+      return op.emitError("matmul: input datatypes mismatched");
+    if (lhsRank < 1 || rhsRank < 1) {
+      return op.emitError("matmul: inputs can't be 0-rank");
+    }
+
+    if (lhsRank <= 2 && rhsRank <= 2) {
+      output = rewriter.create<mhlo::DotOp>(op->getLoc(), lhs, rhs, nullptr);
+      return success();
+    }
+
+    int64_t nBatchDims;
+    if (rhsRank <= 2) {
+      auto leadingRank = lhsRank - 2;
+      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
+      nBatchDims = leadingRank;
+    } else if (lhsRank <= 2) {
+      auto leadingRank = rhsRank - 2;
+      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
+      nBatchDims = leadingRank;
+    } else {
+      assert(rhsRank > 2 && lhsRank > 2);
+      auto leadingRank = std::max(lhsRank - rhsRank, rhsRank - lhsRank);
+      nBatchDims = std::max(lhsRank - 2, rhsRank - 2);
+      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
+    }
+    auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
+    auto lhsContractingDim = nBatchDims + 1;
+    auto rhsContractingDim = nBatchDims;
+    if (lhsRank == 1)
+      lhsContractingDim = nBatchDims;
+
+    mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
+        mhlo::DotDimensionNumbersAttr::get(
+            rewriter.getContext(),
+            /*lhsBatchingDimensions=*/batchDims,
+            /*rhsBatchingDimensions=*/batchDims,
+            /*lhsContractingDimensions=*/{lhsContractingDim},
+            /*rhsContractingDimensions=*/{rhsContractingDim});
+    auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
+                        ->convertType(op.getType())
+                        .template cast<RankedTensorType>();
+
+    output = rewriter
+                 .create<mhlo::DotGeneralOp>(op->getLoc(), resultTy, lhs, rhs,
+                                             dotDimensionNumbers, nullptr)
+                 .getResult();
+
+    return success();
+  }
+
+  // The default version just reads two inputs, computes output and returns it.
+  // Other versions may add a bias, apply GEMM-style alpha/beta scaling etc.
+  virtual LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value lhs, rhs;
+    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
+      return op.emitError("failed to read matmul inputs");
+
+    Value output;
+
+    if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output)))
+      return op.emitError("failed to perform matmul operation");
+
+    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()
+            ->convertType(op.getType())
+            .template cast<RankedTensorType>(),
+        output);
+
+    return success();
+  }
+};
+
+// Legalizes the torch.matmul op for general n-dim matmul.
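+// Illustrative example (not part of the original patch, shapes taken from the
+// tests added below): for a 2-D lhs (tensor<256x120xf32>) and a 3-D rhs
+// (tensor<4x120x256xf32>), the lhs is first broadcast to tensor<4x256x120xf32>
+// and the contraction then lowers to a single mhlo.dot_general, roughly
+//   "mhlo.dot_general"(%lhs, %rhs) {dot_dimension_numbers = #mhlo.dot<
+//       lhs_batching_dimensions = [0], rhs_batching_dimensions = [0],
+//       lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>}
+//     : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
+// i.e. nBatchDims = 1, lhsContractingDim = 2 and rhsContractingDim = 1.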
+template <typename AtenOpT>
+class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+    lhs = adaptor.self();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.other();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError(
+          "only ranked tensor types are supported in MHLO matmul");
+
+    return success();
+  }
+};
+
+// Implements handling of aten.mm and aten.bmm ops.
+template <typename AtenOpT>
+class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+    lhs = adaptor.self();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.mat2();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError(
+          "only ranked tensor types are supported in MHLO matmul");
+
+    auto lhsRank = lhsTy.getRank();
+    auto rhsRank = rhsTy.getRank();
+
+    if (isa<AtenMmOp>(op)) {
+      // Mm takes two 2D tensors.
+      if (lhsRank != 2 || rhsRank != 2)
+        return op.emitError("aten.mm called but matrix rank != 2");
+    } else if (isa<AtenBmmOp>(op)) {
+      // Bmm takes two 3D tensors.
+      if (lhsRank != 3 || rhsRank != 3)
+        return op.emitError("aten.bmm called but matrix rank != 3");
+    }
+
+    return success();
+  }
+};
+
+// Implements handling of aten.linear op.
+template <typename AtenOpT>
+class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+    lhs = adaptor.input();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.weight();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError(
+          "only ranked tensor types are supported in MHLO matmul");
+
+    auto lhsRank = lhsTy.getRank();
+    auto rhsRank = rhsTy.getRank();
+
+    if (lhsRank != 2 && lhsRank != 3)
+      return op.emitError("aten.Linear called but input rank not 2 or 3");
+    if (rhsRank != 2 && rhsRank != 3)
+      return op.emitError("aten.Linear called but weight rank not 2 or 3");
+
+    return success();
+  }
+  // Override the default rewriter to perform RHS transpose and bias addition
+  // as well.
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value lhs, rhs;
+
+    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
+      return op.emitError("failed to read matmul inputs");
+
+    // The aten.Linear op has a bias tensor that is added to the matmul
+    // output.
+    auto bias = adaptor.bias();
+    auto biasTy = bias.getType();
+
+    // MHLO does not mandate that elementwise op tensors need to be ranked.
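+    // Illustrative note (not part of the original patch): conceptually the
+    // lowering below computes
+    //   output = convert(chlo.broadcast_add(
+    //                mhlo.dot_general(input, mhlo.transpose(weight, [1, 0])),
+    //                bias))
+    // and skips the chlo.broadcast_add when bias is the torch `None` value.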
+    if (!biasTy.template isa<Torch::NoneType>() &&
+        !biasTy.template isa<TensorType>())
+      return op.emitError("only ranked tensor types are supported in MHLO "
+                          "matmul for bias tensor");
+
+    // weight.T
+    rhs = getPermutedTensor(rewriter, op, rhs, {1, 0});
+
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+    auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
+                                rhsTy.getRank() - lhsTy.getRank());
+    getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
+    auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
+    auto nBatchDims = resultRank - 2;
+    auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
+    auto lhsContractingDim = nBatchDims + 1;
+    auto rhsContractingDim = nBatchDims;
+
+    mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
+        mhlo::DotDimensionNumbersAttr::get(
+            rewriter.getContext(),
+            /*lhsBatchingDimensions=*/batchDims,
+            /*rhsBatchingDimensions=*/batchDims,
+            /*lhsContractingDimensions=*/{lhsContractingDim},
+            /*rhsContractingDimensions=*/{rhsContractingDim});
+
+    auto resultTy =
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            op.getType());
+
+    Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
+        op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr);
+
+    Value matmulPlusBias = matmulOutput;
+    if (!biasTy.template isa<Torch::NoneType>()) {
+      // Bias addition broadcasts to the matmul output shape.
+      matmulPlusBias =
+          rewriter
+              .create<chlo::BroadcastAddOp>(op->getLoc(), resultTy,
+                                            matmulOutput, bias, nullptr)
+              .getResult();
+    }
+
+    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, resultTy, matmulPlusBias);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  MLIRContext *context = patterns.getContext();
+
+#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp)                                   \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
+  INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
+#undef INSERT_MATMUL_ATENOP_PATTERN
+
+#define INSERT_MM_ATENOP_PATTERN(AtenOp)                                       \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
+  INSERT_MM_ATENOP_PATTERN(AtenMmOp);
+  INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
+#undef INSERT_MM_ATENOP_PATTERN
+
+#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp)                                   \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
+  INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
+#undef INSERT_LINEAR_ATENOP_PATTERN
+}
diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h
index c84cec638..2ff569cd0 100644
--- a/lib/Conversion/TorchToMhlo/PopulatePatterns.h
+++ b/lib/Conversion/TorchToMhlo/PopulatePatterns.h
@@ -28,6 +28,10 @@ void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
 void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
                                             RewritePatternSet &patterns,
                                             ConversionTarget &target);
+void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
+                                         RewritePatternSet &patterns,
+                                         ConversionTarget &target);
+
 } // namespace torch_to_mhlo
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
index 5007a8a26..a8314c7cc 100644
--- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
+++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
@@ -11,6 +11,7 @@
 #include "../PassDetail.h"
 #include "./PopulatePatterns.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include
"mlir/Dialect/Arithmetic/IR/Arithmetic.h" @@ -60,6 +61,8 @@ public: target); torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter, patterns, target); + torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns, + target); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { @@ -73,4 +76,4 @@ public: std::unique_ptr> mlir::torch::createConvertTorchToMhloPass() { return std::make_unique(); -} \ No newline at end of file +} diff --git a/test/Conversion/TorchToMhlo/linear.mlir b/test/Conversion/TorchToMhlo/linear.mlir new file mode 100644 index 000000000..18ea97654 --- /dev/null +++ b/test/Conversion/TorchToMhlo/linear.mlir @@ -0,0 +1,268 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.mm$basic$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32> +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<2x3xf32> +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32> +func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> { + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32> -> !torch.vtensor<[2,3],f32> + return %0 : !torch.vtensor<[2,3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.mm$basic$dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor, tensor<3x?xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.bmm$basic$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[10,3,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[10,3,4],f32> -> tensor<10x3x4xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[10,4,5],f32> -> tensor<10x4x5xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<10x4x5xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T4:.*]] = tensor.dim 
%[[T1]], %[[C1]] : tensor<10x4x5xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<10x3x5xf32> +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32> +func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg1: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[10,3,4],f32>, !torch.vtensor<[10,4,5],f32> -> !torch.vtensor<[10,3,5],f32> + return %0 : !torch.vtensor<[10,3,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.bmm$basic$dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,4],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,4,?],f32> -> tensor +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32> +func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg1: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,4],f32>, !torch.vtensor<[?,4,?],f32> -> !torch.vtensor<[?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$basic$static( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256,120],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256,120],f32> -> tensor<256x120xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : 
!torch.vtensor<[4,120,256],f32> -> tensor<4x120x256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<4x120x256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256x120xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x256x256xf32> +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32> +func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, %arg1: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256,120],f32>, !torch.vtensor<[4,120,256],f32> -> !torch.vtensor<[4,256,256],f32> + return %0 : !torch.vtensor<[4,256,256],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$basic$dynamic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,?,256],f32> -> tensor<4x?x256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x?x256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x?xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32> +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x?x?xf32> +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32> +func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : 
!torch.vtensor<[4,?,256],f32>, !torch.vtensor<[256,?],f32> -> !torch.vtensor<[4,?,?],f32> + return %0 : !torch.vtensor<[4,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$3dx1d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,?,256],f32> -> tensor<1x?x256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<1x?x256xf32> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> +// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> +// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor<1x?xf32> +// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32> +// CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32> +func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,?,256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[1,?],f32> + return %0 : !torch.vtensor<[1,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$1dx3d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,256,?],f32> -> tensor +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> +// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor +// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor +// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[?,256,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + 
+// CHECK-LABEL: func.func @torch.aten.matmul$2dx1d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor, tensor<256xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?],f32> +func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[?],f32> + return %0 : !torch.vtensor<[?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$1dx2d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?],f32> +func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256,?],f32> -> !torch.vtensor<[?],f32> + return %0 : !torch.vtensor<[?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$1dx1d( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[],f32> +func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.matmul$proj( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor +// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 +// CHECK: %[[C0_0:.*]] = arith.constant 0 
: index +// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x256xf32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32> +// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> +// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor +// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,256],f32> +// CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32> +func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { + %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<256x256xf32>) : !torch.vtensor<[256,256],f32> + %1 = torch.aten.matmul %arg0, %0 : !torch.vtensor<[?,?,256],f32>, !torch.vtensor<[256,256],f32> -> !torch.vtensor<[?,?,256],f32> + return %1 : !torch.vtensor<[?,?,256],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.mm$proj( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> { +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor +// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> +// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor, tensor<256x256xf32>) -> tensor +// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor +// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,256],f32> +// CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32> +func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> { + %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<256x256xf32>) : !torch.vtensor<[256,256],f32> + %1 = torch.aten.mm %arg0, %0 : !torch.vtensor<[?,256],f32>, !torch.vtensor<[256,256],f32> -> !torch.vtensor<[?,256],f32> + return %1 : !torch.vtensor<[?,256],f32> +} + diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 2ce5e9b7d..d28545e95 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -450,6 +450,7 @@ cc_library( "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp", "lib/Conversion/TorchToMhlo/BasicOp.cpp", "lib/Conversion/TorchToMhlo/GatherOp.cpp", + "lib/Conversion/TorchToMhlo/Linear.cpp", "lib/Conversion/TorchToMhlo/ViewLikeOps.cpp", "lib/Conversion/TorchToMhlo/ReductionOp.cpp", "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h",