2022-07-25 23:47:46 +08:00
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
//
|
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
|
//
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
|
|
|
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
2022-09-25 22:07:46 +08:00
|
|
|
|
#include "./MhloLegalizeUtils.h"
|
2022-07-25 23:47:46 +08:00
|
|
|
|
#include "./PopulatePatterns.h"
|
|
|
|
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
2022-10-05 21:28:06 +08:00
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2022-07-25 23:47:46 +08:00
|
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
|
|
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
|
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
using namespace mlir::torch;
|
|
|
|
|
using namespace mlir::torch::Torch;
|
2022-09-01 10:36:02 +08:00
|
|
|
|
using namespace mlir::torch::torch_to_mhlo;
|
2022-07-25 23:47:46 +08:00
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
2022-09-01 10:36:02 +08:00
|
|
|
|
Value input, Value indices, int64_t axis,
|
|
|
|
|
size_t dimSizeIndexBits) {
|
2022-07-25 23:47:46 +08:00
|
|
|
|
auto loc = op->getLoc();
|
2022-09-01 10:36:02 +08:00
|
|
|
|
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
2022-07-25 23:47:46 +08:00
|
|
|
|
Value one = rewriter.create<arith::ConstantOp>(
|
|
|
|
|
loc, rewriter.getIntegerAttr(intType, 1));
|
|
|
|
|
|
|
|
|
|
// sliceSizes
|
|
|
|
|
auto inputRankTy = input.getType().dyn_cast<RankedTensorType>();
|
|
|
|
|
auto inputRank = inputRankTy.getRank();
|
|
|
|
|
SmallVector<Value, 4> sliceSizes;
|
|
|
|
|
sliceSizes.reserve(inputRank);
|
|
|
|
|
for (int64_t r = 0; r < inputRank; ++r) {
|
|
|
|
|
if (r == axis) {
|
|
|
|
|
sliceSizes.push_back(one);
|
|
|
|
|
} else {
|
|
|
|
|
sliceSizes.push_back(rewriter.create<arith::IndexCastOp>(
|
|
|
|
|
loc, intType, rewriter.create<tensor::DimOp>(loc, input, r)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto sliceSizesTensor =
|
|
|
|
|
rewriter.create<tensor::FromElementsOp>(loc, sliceSizes);
|
|
|
|
|
|
|
|
|
|
// offsetDims
|
|
|
|
|
SmallVector<int64_t, 4> offsetDims;
|
|
|
|
|
offsetDims.reserve(inputRank);
|
|
|
|
|
for (int64_t r = 0; r < axis; ++r) {
|
|
|
|
|
offsetDims.push_back(r);
|
|
|
|
|
}
|
|
|
|
|
auto indicesRankTy = indices.getType().dyn_cast<RankedTensorType>();
|
|
|
|
|
auto indicesRank = indicesRankTy.getRank();
|
|
|
|
|
for (int64_t r = axis + 1; r < inputRank; ++r) {
|
|
|
|
|
offsetDims.push_back(r + indicesRank - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// collapsedSliceDims
|
|
|
|
|
SmallVector<int64_t, 4> collapsedSliceDims(1, axis);
|
|
|
|
|
// startIndexMap
|
|
|
|
|
SmallVector<int64_t, 4> startIndexMap(1, axis);
|
|
|
|
|
// indexVecDim
|
|
|
|
|
int64_t indexVecDim = indicesRank;
|
|
|
|
|
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
|
|
|
|
|
rewriter.getContext(),
|
|
|
|
|
/*offsetDims=*/offsetDims,
|
|
|
|
|
/*collapsedSliceDims=*/collapsedSliceDims,
|
|
|
|
|
/*startIndexMap=*/startIndexMap,
|
|
|
|
|
/*indexVecDim=*/indexVecDim);
|
|
|
|
|
|
|
|
|
|
// outputShape = input.shape[:axis] + indices.shape +
|
|
|
|
|
// input.shape[axis + 1:]
|
|
|
|
|
auto inputShape = inputRankTy.getShape();
|
|
|
|
|
auto indicesShape = indicesRankTy.getShape();
|
|
|
|
|
SmallVector<int64_t, 4> outputShape(inputShape.begin(),
|
|
|
|
|
inputShape.begin() + axis);
|
|
|
|
|
outputShape.insert(outputShape.end(), indicesShape.begin(),
|
|
|
|
|
indicesShape.end());
|
|
|
|
|
outputShape.insert(outputShape.end(), inputShape.begin() + axis + 1,
|
|
|
|
|
inputShape.end());
|
|
|
|
|
|
|
|
|
|
// create output tensor type
|
|
|
|
|
auto outputTy =
|
|
|
|
|
RankedTensorType::get(outputShape, inputRankTy.getElementType());
|
|
|
|
|
return rewriter
|
|
|
|
|
.create<mhlo::DynamicGatherOp>(loc, outputTy, input, indices,
|
|
|
|
|
sliceSizesTensor, dimsAttr)
|
|
|
|
|
.getResult();
|
|
|
|
|
}
|
2022-09-01 10:36:02 +08:00
|
|
|
|
} // namespace
|
2022-07-25 23:47:46 +08:00
|
|
|
|
|
|
|
|
|
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
|
|
|
|
|
// padding_idx (int, optional)
|
|
|
|
|
// – If specified, the entries at padding_idx do not contribute to the gradient;
|
|
|
|
|
// therefore, the embedding vector at padding_idx is not updated during training,
|
|
|
|
|
// i.e. it remains as a fixed “pad”.
|
|
|
|
|
// scale_grad_by_freq (boolean, optional)
|
|
|
|
|
// – If given, this will scale gradients by the inverse of frequency of the
|
|
|
|
|
// words in the mini-batch. Default False.
|
|
|
|
|
// sparse (bool, optional)
|
|
|
|
|
// – If True, gradient w.r.t. weight matrix will be a sparse tensor.
|
|
|
|
|
template <>
|
|
|
|
|
LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
|
|
|
|
AtenEmbeddingOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
auto weight = adaptor.weight();
|
|
|
|
|
auto weightTy = weight.getType().template cast<RankedTensorType>();
|
|
|
|
|
if (!weightTy)
|
|
|
|
|
return op.emitError("only ranked tensor types are supported");
|
|
|
|
|
|
|
|
|
|
int64_t padding_idx;
|
|
|
|
|
if (!matchPattern(op.padding_idx(), m_TorchConstantInt(&padding_idx)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant padding_idx is currently supported");
|
|
|
|
|
|
|
|
|
|
bool scale_grad_by_freq;
|
|
|
|
|
if (!matchPattern(op.scale_grad_by_freq(),
|
|
|
|
|
m_TorchConstantBool(&scale_grad_by_freq)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant scale_grad_by_freq is currently supported");
|
|
|
|
|
if (scale_grad_by_freq)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "scale gradients is currently not supported");
|
|
|
|
|
bool sparse;
|
|
|
|
|
if (!matchPattern(op.sparse(), m_TorchConstantBool(&sparse)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant sparse is currently supported");
|
|
|
|
|
if (sparse)
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "sparse gradients is currently not supported");
|
|
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
|
Value output = gatherTensorAlongSingleAxis(
|
|
|
|
|
rewriter, op, weight, adaptor.indices(), 0, options.dimSizeIndexBits);
|
2022-07-25 23:47:46 +08:00
|
|
|
|
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
|
|
|
|
op, getTypeConverter()->convertType(op.getType()), output);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
|
|
|
|
AtenIndexSelectOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
auto self = adaptor.self();
|
|
|
|
|
auto selfTy = self.getType().template cast<RankedTensorType>();
|
|
|
|
|
if (!selfTy)
|
|
|
|
|
return op.emitError("only ranked tensor types are supported");
|
|
|
|
|
int64_t dim;
|
|
|
|
|
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant dim is currently supported");
|
|
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
|
Value output = gatherTensorAlongSingleAxis(
|
|
|
|
|
rewriter, op, self, adaptor.index(), dim, options.dimSizeIndexBits);
|
2022-07-25 23:47:46 +08:00
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
|
|
|
|
op, getTypeConverter()->convertType(op.getType()), output);
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
2022-09-25 22:07:46 +08:00
|
|
|
|
// AtenGatherOp
|
|
|
|
|
template <>
|
|
|
|
|
LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
|
|
|
|
AtenGatherOp op, OpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
Value input = adaptor.self();
|
|
|
|
|
Value index = adaptor.index();
|
|
|
|
|
auto inputType = input.getType().cast<RankedTensorType>();
|
|
|
|
|
auto indexType = index.getType().cast<RankedTensorType>();
|
|
|
|
|
auto indexElemType = indexType.getElementType();
|
|
|
|
|
|
|
|
|
|
if (indexType.getRank() != inputType.getRank()) {
|
|
|
|
|
return op.emitError("`index` and `input` param should have the same rank");
|
|
|
|
|
}
|
|
|
|
|
int64_t dim;
|
|
|
|
|
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant int `dim` param supported");
|
|
|
|
|
}
|
|
|
|
|
dim = toPositiveDim(dim, inputType.getRank());
|
|
|
|
|
if (!isValidDim(dim, inputType.getRank())) {
|
|
|
|
|
return rewriter.notifyMatchFailure(op, "invalid `dim` param detected");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool sparseGrad = false;
|
|
|
|
|
if (!matchPattern(op.sparse_grad(), m_TorchConstantBool(&sparseGrad))) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "only constant boolean `sparse_grad` param supported");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto options = getOptions();
|
|
|
|
|
auto indexShapeInfo =
|
|
|
|
|
mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
|
|
|
|
|
if (failed(indexShapeInfo)) {
|
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
|
op, "failed to get dim sizes of `index` param");
|
|
|
|
|
}
|
|
|
|
|
auto intType = rewriter.getIntegerType(options.dimSizeIndexBits);
|
|
|
|
|
auto one = rewriter.create<arith::ConstantOp>(
|
|
|
|
|
loc, rewriter.getIntegerAttr(intType, 1));
|
|
|
|
|
auto toConcatIndexShapeValueVec = *indexShapeInfo;
|
|
|
|
|
toConcatIndexShapeValueVec.push_back(one);
|
|
|
|
|
auto toConcatIndexShape =
|
|
|
|
|
rewriter.create<tensor::FromElementsOp>(loc, toConcatIndexShapeValueVec);
|
|
|
|
|
|
|
|
|
|
auto indexShape = indexType.getShape();
|
|
|
|
|
SmallVector<int64_t> toConcatIndexShapeVec(indexShape.begin(),
|
|
|
|
|
indexShape.end());
|
|
|
|
|
toConcatIndexShapeVec.push_back(1);
|
|
|
|
|
RankedTensorType toConcatIndexType =
|
|
|
|
|
RankedTensorType::get(toConcatIndexShapeVec, indexElemType);
|
|
|
|
|
|
|
|
|
|
SmallVector<Value> toConcat;
|
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
|
|
|
|
if (i == dim) {
|
|
|
|
|
toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
|
|
|
|
|
loc, toConcatIndexType, index, toConcatIndexShape));
|
|
|
|
|
} else {
|
|
|
|
|
toConcat.push_back(rewriter.create<mhlo::DynamicIotaOp>(
|
|
|
|
|
loc, toConcatIndexType, toConcatIndexShape,
|
|
|
|
|
rewriter.getI64IntegerAttr(i)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto gatherIndicies = rewriter.create<mhlo::ConcatenateOp>(
|
|
|
|
|
loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
|
|
|
|
|
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
|
|
|
|
|
|
|
|
|
|
int64_t indexVecDim = inputType.getRank();
|
|
|
|
|
SmallVector<int64_t> collapsedDims;
|
|
|
|
|
SmallVector<int64_t> startIndexMap;
|
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
|
|
|
|
collapsedDims.push_back(i);
|
|
|
|
|
startIndexMap.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
|
|
|
|
|
rewriter.getContext(),
|
|
|
|
|
/*offsetDims=*/{},
|
|
|
|
|
/*collapsedSliceDims=*/collapsedDims,
|
|
|
|
|
/*startIndexMap=*/startIndexMap,
|
|
|
|
|
/*indexVecDim=*/indexVecDim);
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
|
|
|
|
|
op, input, gatherIndicies, dimsAttr,
|
|
|
|
|
rewriter.getI64TensorAttr(sliceSizes));
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
2022-07-25 23:47:46 +08:00
|
|
|
|
void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
|
|
|
|
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
2022-09-01 10:36:02 +08:00
|
|
|
|
ConversionTarget &target, const TorchToMhloOptions &options) {
|
2022-07-25 23:47:46 +08:00
|
|
|
|
MLIRContext *context = patterns.getContext();
|
|
|
|
|
|
|
|
|
|
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2022-09-01 10:36:02 +08:00
|
|
|
|
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
2022-07-25 23:47:46 +08:00
|
|
|
|
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
|
|
|
|
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
2022-09-25 22:07:46 +08:00
|
|
|
|
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
2022-07-25 23:47:46 +08:00
|
|
|
|
#undef INSERT_ATENOP_PATTERN
|
|
|
|
|
}
|