mirror of https://github.com/llvm/torch-mlir
[MHLO] Init MHLO linear op patterns (#1132)
See RFC: https://github.com/llvm/torch-mlir/issues/999

Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
parent 48ec300586
commit f0a24f59f6
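For a quick sense of what the new patterns emit, here is the simplest case, lifted from the test file added below: a fully static torch.aten.mm lowers straight to mhlo.dot, followed by an mhlo.convert cast to the converted result type. Batched, broadcasting, and 1-D cases instead go through mhlo.dynamic_broadcast_in_dim plus mhlo.dot_general, as the later tests show.

func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> {
  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32> -> !torch.vtensor<[2,3],f32>
  return %0 : !torch.vtensor<[2,3],f32>
}

// After -convert-torch-to-mhlo the matmul itself becomes roughly:
//   %2 = "mhlo.dot"(%0, %1) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32>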

lib/Conversion/TorchToMhlo/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
  MhloLegalizeUtils.cpp
  BasicOp.cpp
  GatherOp.cpp
  Linear.cpp
  ViewLikeOps.cpp
  ReductionOp.cpp

lib/Conversion/TorchToMhlo/Linear.cpp (new file)
@@ -0,0 +1,405 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"

#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
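// Broadcasts `tensor` to the target `shape` via mhlo.dynamic_broadcast_in_dim.
// `dimSizes` carries the (possibly dynamic) extent of every output dimension
// and `broadcastDims` maps each input dimension to its output position.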
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
                         ArrayRef<int64_t> shape, ArrayRef<Value> dimSizes,
                         ArrayRef<int64_t> broadcastDims) {
  auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
  auto loc = op->getLoc();
  Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);

  RankedTensorType outTy =
      RankedTensorType::get(shape, tensorTy.getElementType());

  RankedTensorType attrTy =
      RankedTensorType::get({static_cast<int64_t>(broadcastDims.size())},
                            rewriter.getIntegerType(64));
  auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims);

  auto broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
      loc, outTy, tensor, mhloShape, broadcastAttr);
  return broadcast;
}

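// Transposes `input` according to the permutation `inpTransDims` (negative
// dimensions are normalized first) by emitting an mhlo.transpose.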
Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
                        ArrayRef<int64_t> inpTransDims) {
  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
  auto rank = inputTy.getRank();
  auto transDims = mhlo::toPositiveDims(inpTransDims, rank);
  auto inpShape = inputTy.getShape();
  std::vector<int64_t> newShape;
  newShape.reserve(rank);

  for (auto d : transDims) {
    newShape.push_back(inpShape[d]);
  }

  auto attrTy = RankedTensorType::get({static_cast<int64_t>(transDims.size())},
                                      rewriter.getIntegerType(64));
  auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims);

  auto outTy = RankedTensorType::get(newShape, inputTy.getElementType());
  auto result = rewriter.create<mhlo::TransposeOp>(op->getLoc(), outTy, input,
                                                   permuteAttr);
  return result.getResult();
}

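// Prepends `leadingRank` broadcast batch dimensions to whichever operand has
// the lower rank so that, on return, both operands expose the same leading
// batch dims for the mhlo.dot_general that follows.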
void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
                     Value &inpRhs, int64_t leadingRank) {
  Value lhs = inpLhs;
  Value rhs = inpRhs;
  auto lhsRankTy = inpLhs.getType().dyn_cast<RankedTensorType>();
  auto rhsRankTy = inpRhs.getType().dyn_cast<RankedTensorType>();

  auto lhsRank = lhsRankTy.getRank();
  auto rhsRank = rhsRankTy.getRank();

  // The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be
  // broadcastable).
  auto minRank = std::min(lhsRank, rhsRank);
  auto leadingDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, leadingRank));
  auto broadcastDims = llvm::to_vector<4>(
      llvm::seq<int64_t>(leadingRank, minRank + leadingRank));
  auto lhsShape = lhsRankTy.getShape();
  auto rhsShape = rhsRankTy.getShape();
  if (lhsRank < rhsRank) {
    std::vector<int64_t> newShape(rhsShape.begin(),
                                  rhsShape.begin() + leadingRank);
    newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
    auto newDimSizes =
        *mhlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims);
    auto lhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, lhs);
    newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
                       lhsDimSizes.end());
    lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
                             broadcastDims);
  } else {
    std::vector<int64_t> newShape(lhsShape.begin(),
                                  lhsShape.begin() + leadingRank);
    newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
    auto newDimSizes =
        *mhlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims);
    auto rhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, rhs);
    newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
                       rhsDimSizes.end());
    rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
                             broadcastDims);
  }

  inpLhs = lhs;
  inpRhs = rhs;
}

// Performs the basic n-dim matmul operation, including the handling of
// broadcasting and dynamic shape propagation.
// All PyTorch ops that lower to a matrix multiplication derive from this base
// pattern and implement their own specialized input processing
// (e.g. transpose) and output processing (e.g. GEMM or fully connected bias
// handling).
template <typename AtenOpT>
class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  // Each variant must implement corresponding parameter parsing options.
  // Maintain separate input read functions for each variant because it is not
  // necessarily true with all variants that the first two operands are the lhs
  // and rhs.
  virtual LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                         ConversionPatternRewriter &rewriter,
                                         Value &lhs, Value &rhs) const {
    return rewriter.notifyMatchFailure(
        op,
        "unimplemented matrix multiplication variant input parsing function");
  }
  LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter, Value &lhs,
                              Value &rhs, Value &output) const {
    auto lhsTy = lhs.getType().cast<RankedTensorType>();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    auto lhsRank = lhsTy.getRank();
    auto rhsRank = rhsTy.getRank();
    auto lhsElemTy = lhsTy.getElementType();
    auto rhsElemTy = rhsTy.getElementType();

    if (lhsElemTy != rhsElemTy)
      return op.emitError("matmul: input datatypes mismatched");
    if (lhsRank < 1 || rhsRank < 1) {
      return op.emitError("matmul: inputs can't be 0-rank");
    }

    if (lhsRank <= 2 && rhsRank <= 2) {
      output = rewriter.create<mhlo::DotOp>(op->getLoc(), lhs, rhs, nullptr);
      return success();
    }

    int64_t nBatchDims;
    if (rhsRank <= 2) {
      auto leadingRank = lhsRank - 2;
      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
      nBatchDims = leadingRank;
    } else if (lhsRank <= 2) {
      auto leadingRank = rhsRank - 2;
      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
      nBatchDims = leadingRank;
    } else {
      assert(rhsRank > 2 && lhsRank > 2);
      auto leadingRank = std::max(lhsRank - rhsRank, rhsRank - lhsRank);
      nBatchDims = std::max(lhsRank - 2, rhsRank - 2);
      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
    }
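    // After broadcasting, both operands share the same leading batch dims;
    // contract the last dim of the lhs with the first non-batch dim of the
    // rhs. A rank-1 lhs has no row dimension, so its single dim is the
    // contracting one.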
    auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
    auto lhsContractingDim = nBatchDims + 1;
    auto rhsContractingDim = nBatchDims;
    if (lhsRank == 1)
      lhsContractingDim = nBatchDims;

    mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
        mhlo::DotDimensionNumbersAttr::get(
            rewriter.getContext(),
            /*lhsBatchingDimensions=*/batchDims,
            /*rhsBatchingDimensions=*/batchDims,
            /*lhsContractingDimensions=*/{lhsContractingDim},
            /*rhsContractingDimensions=*/{rhsContractingDim});
    auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
                        ->convertType(op.getType())
                        .template cast<RankedTensorType>();

    output = rewriter
                 .create<mhlo::DotGeneralOp>(op->getLoc(), resultTy, lhs, rhs,
                                             dotDimensionNumbers, nullptr)
                 .getResult();

    return success();
  }

  // The default version just reads two inputs, computes output and returns it.
  // Other versions may add a bias, apply GEMM-style alpha/beta scaling etc.
  virtual LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs, rhs;
    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
      return op.emitError("failed to read matmul inputs");

    Value output;

    if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output)))
      return op.emitError("failed to perform matmul operation");

    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
        op,
        OpConversionPattern<AtenOpT>::getTypeConverter()
            ->convertType(op.getType())
            .template cast<RankedTensorType>(),
        output);

    return success();
  }
};

// Legalizes the torch.matmul op for general n-dim matmul.
template <typename AtenOpT>
class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
public:
  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Value &lhs, Value &rhs) const override {
    lhs = adaptor.self();
    auto lhsTy = lhs.getType().cast<RankedTensorType>();

    rhs = adaptor.other();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError(
          "only ranked tensor types are supported in MHLO matmul");

    return success();
  }
};

// Implements handling of aten.mm and aten.bmm ops.
template <typename AtenOpT>
class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
public:
  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Value &lhs, Value &rhs) const override {
    lhs = adaptor.self();
    auto lhsTy = lhs.getType().cast<RankedTensorType>();

    rhs = adaptor.mat2();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError(
          "only ranked tensor types are supported in MHLO matmul");

    auto lhsRank = lhsTy.getRank();
    auto rhsRank = rhsTy.getRank();

    if (isa<AtenMmOp>(op)) {
      // Mm takes two 2D tensors.
      if (lhsRank != 2 || rhsRank != 2)
        return op.emitError("aten.mm called but matrix rank != 2");
    } else if (isa<AtenBmmOp>(op)) {
      // Bmm takes two 3D tensors.
      if (lhsRank != 3 || rhsRank != 3)
        return op.emitError("aten.bmm called but matrix rank != 3");
    }

    return success();
  }
};

// Implements handling of aten.linear op.
template <typename AtenOpT>
class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
public:
  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Value &lhs, Value &rhs) const override {
    lhs = adaptor.input();
    auto lhsTy = lhs.getType().cast<RankedTensorType>();

    rhs = adaptor.weight();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    if (!lhsTy || !rhsTy)
      return op.emitError(
          "only ranked tensor types are supported in MHLO matmul");

    auto lhsRank = lhsTy.getRank();
    auto rhsRank = rhsTy.getRank();

    if (lhsRank != 2 && lhsRank != 3)
      return op.emitError("aten.Linear called but input rank not 2 or 3");
    if (rhsRank != 2 && rhsRank != 3)
      return op.emitError("aten.Linear called but weight rank not 2 or 3");

    return success();
  }
  // Override the default rewriter to perform RHS transpose and bias addition
  // as well.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs, rhs;

    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
      return op.emitError("failed to read matmul inputs");

    // The aten.Linear op has a bias tensor that is added to the matmul
    // output.
    auto bias = adaptor.bias();
    auto biasTy = bias.getType();

    // MHLO does not mandate that elementwise op tensors need to be ranked.
    if (!biasTy.template isa<Torch::NoneType>() &&
        !biasTy.template isa<RankedTensorType>())
      return op.emitError("only ranked tensor types are supported in MHLO "
                          "matmul for bias tensor");

    // weight.T
    rhs = getPermutedTensor(rewriter, op, rhs, {1, 0});

    auto lhsTy = lhs.getType().cast<RankedTensorType>();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();
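    // The transposed weight is always 2-D; when the input is 3-D, broadcast
    // the weight across the input's batch dimension so that both operands
    // expose matching batch dims for mhlo.dot_general.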
    auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
                                rhsTy.getRank() - lhsTy.getRank());
    getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
    auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
    auto nBatchDims = resultRank - 2;
    auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
    auto lhsContractingDim = nBatchDims + 1;
    auto rhsContractingDim = nBatchDims;

    mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
        mhlo::DotDimensionNumbersAttr::get(
            rewriter.getContext(),
            /*lhsBatchingDimensions=*/batchDims,
            /*rhsBatchingDimensions=*/batchDims,
            /*lhsContractingDimensions=*/{lhsContractingDim},
            /*rhsContractingDimensions=*/{rhsContractingDim});

    auto resultTy =
        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
            op.getType());

    Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
        op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr);

    Value matmulPlusBias = matmulOutput;
    if (!biasTy.template isa<Torch::NoneType>()) {
      // Bias addition broadcasts to the matmul output shape.
      matmulPlusBias =
          rewriter
              .create<chlo::BroadcastAddOp>(op->getLoc(), resultTy,
                                            matmulOutput, bias, nullptr)
              .getResult();
    }

    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, resultTy, matmulPlusBias);
    return success();
  }
};

} // namespace

void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();

#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp)                                   \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
  INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
#undef INSERT_MATMUL_ATENOP_PATTERN

#define INSERT_MM_ATENOP_PATTERN(AtenOp)                                       \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
  INSERT_MM_ATENOP_PATTERN(AtenMmOp);
  INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
#undef INSERT_MM_ATENOP_PATTERN

#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp)                                   \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
  INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATENOP_PATTERN
}

lib/Conversion/TorchToMhlo/PopulatePatterns.h
@@ -28,6 +28,10 @@ void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
                                            RewritePatternSet &patterns,
                                            ConversionTarget &target);
void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         ConversionTarget &target);

} // namespace torch_to_mhlo
} // namespace torch
} // namespace mlir

lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
@@ -11,6 +11,7 @@

#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"

@@ -60,6 +61,8 @@ public:
                                                          target);
    torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter,
                                                          patterns, target);
    torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
                                                       target);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {

@@ -73,4 +76,4 @@ public:
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass() {
  return std::make_unique<ConvertTorchToMhlo>();
}

test/Conversion/TorchToMhlo/linear.mlir (new file)
@@ -0,0 +1,268 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func.func @torch.aten.mm$basic$static(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32>
// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<2x3xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32>
func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> {
  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,3],f32> -> !torch.vtensor<[2,3],f32>
  return %0 : !torch.vtensor<[2,3],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.mm$basic$dynamic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor<?x3xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?x?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.bmm$basic$static(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[10,3,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[10,3,4],f32> -> tensor<10x3x4xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[10,4,5],f32> -> tensor<10x4x5xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<10x4x5xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<10x4x5xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<10x3x5xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32>
func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg1: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> {
  %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[10,3,4],f32>, !torch.vtensor<[10,4,5],f32> -> !torch.vtensor<[10,3,5],f32>
  return %0 : !torch.vtensor<[10,3,5],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.bmm$basic$dynamic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,4],f32> -> tensor<?x?x4xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,4,?],f32> -> tensor<?x4x?xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<?x4x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<?x4x?xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<?x4x?xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<?x?x?xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg1: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
  %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,4],f32>, !torch.vtensor<[?,4,?],f32> -> !torch.vtensor<[?,?,?],f32>
  return %0 : !torch.vtensor<[?,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.matmul$basic$static(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256,120],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256,120],f32> -> tensor<256x120xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,120,256],f32> -> tensor<4x120x256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<4x120x256xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256x120xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x256x256xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32>
func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, %arg1: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256,120],f32>, !torch.vtensor<[4,120,256],f32> -> !torch.vtensor<[4,256,256],f32>
  return %0 : !torch.vtensor<[4,256,256],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.matmul$basic$dynamic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,?,256],f32> -> tensor<4x?x256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x?x256xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x?xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x?x?xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32>
func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,?,256],f32>, !torch.vtensor<[256,?],f32> -> !torch.vtensor<[4,?,?],f32>
  return %0 : !torch.vtensor<[4,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.matmul$3dx1d(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,?,256],f32> -> tensor<1x?x256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<1x?x256xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor<1x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32>
// CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32>
func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,?,256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[1,?],f32>
  return %0 : !torch.vtensor<[1,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.matmul$1dx3d(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,256,?],f32> -> tensor<?x256x?xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<?x256x?xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor<?x?xf32>
// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[?,256,?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.matmul$2dx1d(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32>
// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[?],f32>
  return %0 : !torch.vtensor<[?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.matmul$1dx2d(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32>
// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256,?],f32> -> !torch.vtensor<[?],f32>
  return %0 : !torch.vtensor<[?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.matmul$1dx1d(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<f32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[],f32>
func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> {
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32> -> !torch.vtensor<[],f32>
  return %0 : !torch.vtensor<[],f32>
}

// -----

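// The two $proj cases below multiply an activation by a constant weight
// (torch.vtensor.literal); the 3-d case exercises broadcasting the 2-d weight
// across the batch dimension before mhlo.dot_general.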
// CHECK-LABEL: func.func @torch.aten.matmul$proj(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor<?x?x256xf32>
// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x256xf32>
// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x256xf32>
// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32>
// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64
// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<?x?x256xf32>
// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x256xf32> -> !torch.vtensor<[?,?,256],f32>
// CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32>
func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> {
  %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<256x256xf32>) : !torch.vtensor<[256,256],f32>
  %1 = torch.aten.matmul %arg0, %0 : !torch.vtensor<[?,?,256],f32>, !torch.vtensor<[256,256],f32> -> !torch.vtensor<[?,?,256],f32>
  return %1 : !torch.vtensor<[?,?,256],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.mm$proj(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32>
// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?x256xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x256xf32> -> !torch.vtensor<[?,256],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32>
func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> {
  %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<256x256xf32>) : !torch.vtensor<[256,256],f32>
  %1 = torch.aten.mm %arg0, %0 : !torch.vtensor<[?,256],f32>, !torch.vtensor<[256,256],f32> -> !torch.vtensor<[?,256],f32>
  return %1 : !torch.vtensor<[?,256],f32>
}

@@ -450,6 +450,7 @@ cc_library(
        "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp",
        "lib/Conversion/TorchToMhlo/BasicOp.cpp",
        "lib/Conversion/TorchToMhlo/GatherOp.cpp",
        "lib/Conversion/TorchToMhlo/Linear.cpp",
        "lib/Conversion/TorchToMhlo/ViewLikeOps.cpp",
        "lib/Conversion/TorchToMhlo/ReductionOp.cpp",
        "lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h",