[MHLO] refactor pass configurations (#1315)

Related to https://github.com/llvm/torch-mlir/issues/1227

1. Reduce MHLO #ifdefs
2. Dismiss compilation warnings
pull/1321/head
Tanyo Kwok 2022-09-01 10:36:02 +08:00 committed by GitHub
parent 7769eb88f8
commit 29cafdbb61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 373 additions and 380 deletions

View File

@ -39,14 +39,6 @@ endmacro()
option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
if(TORCH_MLIR_ENABLE_MHLO)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
# The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
# One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
# the range of i32(4GiB)
option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
"Enable truncate dimension size from i64 to i32(unsafely)" OFF)
if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
endif()
endif()
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)

View File

@ -132,6 +132,17 @@ def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
Convert Torch ops to mhlo ops.
}];
let constructor = "mlir::torch::createConvertTorchToMhloPass()";
// Specify any options.
let options = [
Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false",
"Enable static shape conversion">,
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
// are unlikely to exceed the range of i32(4GiB)
Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
"Enable truncate index from i64 to i32(unsafely)">,
];
}
#endif

View File

@ -17,6 +17,8 @@
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
} // namespace torch
} // namespace mlir

View File

@ -29,6 +29,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
bool skipMultiplyAlpha(Value alphaValue) {
double doubleValue;
@ -379,63 +380,58 @@ public:
} // namespace
// AtenBroadcastToOp
namespace {
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.self();
auto selfTy = self.getType().cast<RankedTensorType>();
auto outType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
template <>
LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
AtenBroadcastToOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.self();
auto selfTy = self.getType().cast<RankedTensorType>();
auto outType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
#ifdef TORCH_MLIR_ENABLE_MHLO_STATIC_SHAPE
if (selfTy.hasStaticShape()) {
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
rewriter.replaceOp(op, bcastOp);
return success();
if (options.enableStaticShape && selfTy.hasStaticShape()) {
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
rewriter.replaceOp(op, bcastOp);
return success();
}
SmallVector<Value> shape;
if (!(getListConstructElements(adaptor.size(), shape))) {
return op->emitError("desired shape must be a list of scalar");
}
SmallVector<Value> bcastShapeVec;
int64_t totalRank = shape.size();
int64_t selfRank = selfTy.getRank();
int64_t leadingRank = totalRank - selfRank;
for (int64_t i = 0; i < totalRank; ++i) {
Value dValue = shape[i];
Value newD;
int64_t dInt;
if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) {
return op->emitError("element of desired shape must be a scalar");
}
#endif
SmallVector<Value> shape;
if (!(getListConstructElements(adaptor.size(), shape))) {
return op->emitError("desired shape must be a list of scalar");
if (i >= leadingRank && dInt == -1) {
newD = rewriter.create<mlir::tensor::DimOp>(op->getLoc(), self,
i - leadingRank);
} else {
dValue = rewriter.create<torch::TorchConversion::ToI64Op>(op->getLoc(),
dValue);
newD = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getIndexType(), dValue);
}
SmallVector<Value> bcastShapeVec;
int64_t totalRank = shape.size();
int64_t selfRank = selfTy.getRank();
int64_t leadingRank = totalRank - selfRank;
bcastShapeVec.push_back(newD);
}
for (int64_t i = 0; i < totalRank; ++i) {
Value dValue = shape[i];
Value newD;
int64_t dInt;
if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) {
return op->emitError("element of desired shape must be a scalar");
}
if (i >= leadingRank && dInt == -1) {
newD = rewriter.create<mlir::tensor::DimOp>(op->getLoc(), self,
i - leadingRank);
} else {
dValue = rewriter.create<torch::TorchConversion::ToI64Op>(op->getLoc(),
dValue);
newD = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getIndexType(), dValue);
}
bcastShapeVec.push_back(newD);
}
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
if (options.dimSizeIndexBits == 32) {
for (auto &dsize : bcastShapeVec) {
auto dsizeI64 = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getI64Type(), dsize);
dsize = rewriter.create<arith::TruncIOp>(op->getLoc(),
rewriter.getI32Type(), dsizeI64);
}
#endif
}
Value bcastShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), ValueRange{bcastShapeVec});
@ -445,66 +441,45 @@ public:
op, outType, self, bcastShapeTensor,
rewriter.getI64TensorAttr(dimensionNumbers));
return success();
}
};
} // namespace
}
// AtenPermuteOp
namespace {
class ConvertAtenPermuteOp : public OpConversionPattern<AtenPermuteOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenPermuteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.self();
// Not a ranked tensor type
auto inType = self.getType().dyn_cast<RankedTensorType>();
auto outType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
if (!inType)
return op.emitError("only ranked tensor types with static shapes are "
"currently supported");
template <>
LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
AtenPermuteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.self();
// Not a ranked tensor type
auto inType = self.getType().dyn_cast<RankedTensorType>();
auto outType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
if (!inType)
return op.emitError("only ranked tensor types with static shapes are "
"currently supported");
SmallVector<int64_t> permValues;
if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues)))
return rewriter.notifyMatchFailure(
op, "only constant dimensions are currently supported");
SmallVector<int64_t> permValues;
if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues)))
return rewriter.notifyMatchFailure(
op, "only constant dimensions are currently supported");
int64_t inRank = inType.getRank();
for (auto &d : permValues) {
d = toPositiveDim(d, inRank);
if (!isValidDim(d, inRank))
return op.emitError("not all dims are valid");
}
DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(permValues.size())},
rewriter.getI64Type()),
permValues);
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
permutation);
return success();
int64_t inRank = inType.getRank();
for (auto &d : permValues) {
d = toPositiveDim(d, inRank);
if (!isValidDim(d, inRank))
return op.emitError("not all dims are valid");
}
};
} // namespace
namespace {
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(permValues.size())},
rewriter.getI64Type()),
permValues);
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
permutation);
return success();
}
// AtenTanhOp
namespace {
template <>
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
AtenTanhOp op, OpAdaptor adaptor,
@ -520,10 +495,8 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
"only floating-point datatype legalization currently supported");
}
}
} // namespace
// ValueTensorLiteralOp
namespace {
template <>
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
ValueTensorLiteralOp op, OpAdaptor adaptor,
@ -553,11 +526,9 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
return success();
}
} // namespace
// AtenReciprocalOp
// Reciprocal(x) = Div(1, x)
namespace {
template <>
LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
AtenReciprocalOp op, OpAdaptor adaptor,
@ -575,10 +546,8 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
return success();
}
} // namespace
// PrimNumToTensorScalarOp
namespace {
template <>
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
PrimNumToTensorScalarOp op, OpAdaptor adaptor,
@ -592,11 +561,9 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
rewriter.replaceOp(op, mhloTensor);
return success();
}
} // namespace
// AtenContiguousOp
// Ref: TosaToTosa.cpp for implementation details
namespace {
template <>
LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
AtenContiguousOp op, OpAdaptor adaptor,
@ -614,11 +581,9 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
return success();
}
} // namespace
// AtenReluOp
// Relu(x) = Max(0, x)
namespace {
template <>
LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
AtenReluOp op, OpAdaptor adaptor,
@ -641,11 +606,9 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
return success();
}
} // namespace
// Convert a Aten::GELU to HLO
// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
namespace {
template <>
LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
AtenGeluOp op, OpAdaptor adaptor,
@ -668,10 +631,8 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
return success();
}
} // namespace
// AtenErfOp
namespace {
template <>
LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
AtenErfOp op, OpAdaptor adaptor,
@ -686,10 +647,8 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
return success();
}
} // namespace
// AtenBatchNormOp
namespace {
template <>
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
AtenBatchNormOp op, OpAdaptor adaptor,
@ -716,12 +675,12 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1);
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getI64Type(), channelDim);
channelDim = rewriter.create<arith::TruncIOp>(
op->getLoc(), rewriter.getI32Type(), channelDimI64);
#endif
if (options.dimSizeIndexBits == 32) {
auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getI64Type(), channelDim);
channelDim = rewriter.create<arith::TruncIOp>(
op->getLoc(), rewriter.getI32Type(), channelDimI64);
}
Value channelShape = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ValueRange{channelDim});
@ -806,10 +765,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
}
}
} // namespace
// AtenNativeLayerNormOp
namespace {
template <>
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
AtenNativeLayerNormOp op, OpAdaptor adaptor,
@ -949,10 +906,8 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
return success();
}
} // namespace
// AtenCatOp
namespace {
template <>
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
AtenCatOp op, OpAdaptor adaptor,
@ -983,10 +938,8 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
op, outType, ValueRange(builtinTensors), posDim);
return success();
}
} // namespace
// AtenNumelOp
namespace {
template <>
LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
AtenNumelOp op,
@ -996,7 +949,7 @@ LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
auto selfTy = self.getType().dyn_cast<RankedTensorType>();
size_t rank = selfTy.getRank();
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
auto loc = op->getLoc();
Value numel =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(intType, 1));
@ -1015,26 +968,18 @@ LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
}
return success();
}
} // namespace
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenTransposeIntOp>();
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
target.addIllegalOp<AtenBroadcastToOp>();
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
target.addIllegalOp<AtenPermuteOp>();
patterns.add<ConvertAtenPermuteOp>(typeConverter, context);
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, \
context);
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, context)
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp);
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp);
@ -1045,14 +990,14 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
context);
context)
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context);
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context)
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
@ -1062,7 +1007,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context);
patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context)
INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
@ -1072,7 +1017,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context);
patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context)
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
@ -1086,7 +1031,11 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp);

View File

@ -23,18 +23,14 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
static constexpr size_t kMhloDimSizeBits = 32;
#else
static constexpr size_t kMhloDimSizeBits = 64;
#endif
using namespace mlir::torch::torch_to_mhlo;
namespace {
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
Value input, Value indices, int64_t axis) {
Value input, Value indices, int64_t axis,
size_t dimSizeIndexBits) {
auto loc = op->getLoc();
Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
Value one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
@ -98,16 +94,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
sliceSizesTensor, dimsAttr)
.getResult();
}
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// padding_idx (int, optional)
@ -149,8 +136,8 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "sparse gradients is currently not supported");
Value output =
gatherTensorAlongSingleAxis(rewriter, op, weight, adaptor.indices(), 0);
Value output = gatherTensorAlongSingleAxis(
rewriter, op, weight, adaptor.indices(), 0, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output);
@ -170,24 +157,23 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "only constant dim is currently supported");
Value output =
gatherTensorAlongSingleAxis(rewriter, op, self, adaptor.index(), dim);
Value output = gatherTensorAlongSingleAxis(
rewriter, op, self, adaptor.index(), dim, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output);
return success();
}
} // namespace
void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
#undef INSERT_ATENOP_PATTERN

View File

@ -25,6 +25,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
namespace {
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
@ -71,7 +72,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
}
void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
Value &inpRhs, int64_t leadingRank) {
Value &inpRhs, int64_t leadingRank,
size_t dimSizeIndexBits) {
Value lhs = inpLhs;
Value rhs = inpRhs;
auto lhsRankTy = inpLhs.getType().dyn_cast<RankedTensorType>();
@ -92,9 +94,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(rhsShape.begin(),
rhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
auto newDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims);
auto lhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, lhs);
auto newDimSizes = *mhlo::getDimSizesOfTensor(
rewriter, op, rhs, leadingDims, dimSizeIndexBits);
auto lhsDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
lhsDimSizes.end());
lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
@ -103,9 +106,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(lhsShape.begin(),
lhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
auto newDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims);
auto rhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, rhs);
auto newDimSizes = *mhlo::getDimSizesOfTensor(
rewriter, op, lhs, leadingDims, dimSizeIndexBits);
auto rhsDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
rhsDimSizes.end());
rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
@ -122,9 +126,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
// implement their specialized input processing (e.g transpose), and output
// processing, e.g. GEMM or fully connected bias handling.
template <typename AtenOpT>
class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
class ConvertAtenMatmulBaseOp : public ConvertAtenOp<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor;
// Each variant must implement corresponding parameter parsing options.
// Maintain separate input read functions for each variant because it is not
@ -159,20 +163,24 @@ public:
return success();
}
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
int64_t nBatchDims;
if (rhsRank <= 2) {
auto leadingRank = lhsRank - 2;
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
options.dimSizeIndexBits);
nBatchDims = leadingRank;
} else if (lhsRank <= 2) {
auto leadingRank = rhsRank - 2;
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
options.dimSizeIndexBits);
nBatchDims = leadingRank;
} else {
assert(rhsRank > 2 && lhsRank > 2);
auto leadingRank = std::max(lhsRank - rhsRank, rhsRank - lhsRank);
nBatchDims = std::max(lhsRank - 2, rhsRank - 2);
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
options.dimSizeIndexBits);
}
auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
auto lhsContractingDim = nBatchDims + 1;
@ -187,7 +195,7 @@ public:
/*rhsBatchingDimensions=*/batchDims,
/*lhsContractingDimensions=*/{lhsContractingDim},
/*rhsContractingDimensions=*/{rhsContractingDim});
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
auto resultTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();
@ -215,7 +223,7 @@ public:
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()
ConvertAtenOp<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>(),
output);
@ -340,7 +348,10 @@ public:
auto rhsTy = rhs.getType().cast<RankedTensorType>();
auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
rhsTy.getRank() - lhsTy.getRank());
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
options.dimSizeIndexBits);
auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
auto nBatchDims = resultRank - 2;
auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
@ -356,8 +367,7 @@ public:
/*rhsContractingDimensions=*/{rhsContractingDim});
auto resultTy =
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType());
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr);
@ -377,12 +387,9 @@ public:
}
};
} // namespace
namespace {
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
public:
using OpConversionPattern<AtenConvolutionOp>::OpConversionPattern;
using ConvertAtenOp<AtenConvolutionOp>::ConvertAtenOp;
using OpAdaptor = typename AtenConvolutionOp::Adaptor;
Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op,
@ -390,8 +397,9 @@ public:
auto weightTy = weight.getType().cast<RankedTensorType>();
auto weightElemTy = weightTy.getElementType();
auto rank = weightTy.getRank();
SmallVector<Value> weightShapeVec =
*mhlo::getDimSizesOfTensor(rewriter, op, weight);
const auto &options = getOptions();
SmallVector<Value> weightShapeVec = *mhlo::getDimSizesOfTensor(
rewriter, op, weight, options.dimSizeIndexBits);
auto weightShape = weightTy.getShape();
SmallVector<int64_t> weightShapeInt(rank);
std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin());
@ -601,9 +609,8 @@ public:
return mhloConvOp.getResult();
}
LogicalResult
matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.input();
Value weight = adaptor.weight();
@ -714,7 +721,10 @@ public:
// Reshape and promote bias
auto inputUnsqzDims =
llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
const auto &options = getOptions();
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
options.dimSizeIndexBits);
bias = mhlo::promoteType(rewriter, bias, outTy);
DenseIntElementsAttr bcastDimensions;
@ -727,31 +737,31 @@ public:
void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context, options)
INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
#undef INSERT_MATMUL_ATEMOP_PATTERN
#define INSERT_MM_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context, options)
INSERT_MM_ATENOP_PATTERN(AtenMmOp);
INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
#undef INSERT_MM_ATEMOP_PATTERN
#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context, options)
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATEMOP_PATTERN
#define INSERT_CONVOLUTION_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context, options)
INSERT_CONVOLUTION_ATENOP_PATTERN(AtenConvolutionOp);
#undef INSERT_CONVOLUTION_ATENOP_PATTERN
}

View File

@ -259,9 +259,10 @@ SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
return posDims;
}
FailureOr<SmallVector<Value, 4>>
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
ArrayRef<int64_t> inpDims) {
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
Operation *op, Value value,
ArrayRef<int64_t> inpDims,
size_t dimSizeIndexBits) {
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
if (!valueTy) {
return rewriter.notifyMatchFailure(
@ -276,14 +277,15 @@ getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
auto loc = op->getLoc();
for (auto d : dims) {
dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIntegerType(kMhloDimSizeBits),
loc, rewriter.getIntegerType(dimSizeIndexBits),
rewriter.create<tensor::DimOp>(loc, value, d)));
}
return dimSizes;
}
FailureOr<SmallVector<Value, 4>>
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
Operation *op, Value value,
size_t dimSizeIndexBits) {
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
if (!valueTy) {
return rewriter.notifyMatchFailure(
@ -294,12 +296,12 @@ getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
// Get int vector [0, 1, ..., rank-1]
std::vector<int64_t> dims(rank);
std::iota(dims.begin(), dims.end(), 0);
return getDimSizesOfTensor(rewriter, op, value, dims);
return getDimSizesOfTensor(rewriter, op, value, dims, dimSizeIndexBits);
}
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value tensor,
ArrayRef<int64_t> inputUnsqzDims) {
Value tensor, ArrayRef<int64_t> inputUnsqzDims,
size_t dimSizeIndexBits) {
// Returns a new tensor with dims of size 1 inserted at the specified
// position.
//
@ -307,7 +309,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
// tensor) are specified with unsqzDims. Indices must be in-order, and in
// range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1,
// 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not.
auto dimSizesInfo = getDimSizesOfTensor(rewriter, op, tensor);
auto dimSizesInfo =
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -324,7 +327,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
auto loc = op->getLoc();
auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
auto one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));

View File

@ -19,11 +19,6 @@
namespace mlir {
namespace mhlo {
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
static constexpr size_t kMhloDimSizeBits = 32;
#else
static constexpr size_t kMhloDimSizeBits = 64;
#endif
using mlir::ConversionPatternRewriter;
@ -60,22 +55,23 @@ SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank);
// Get the dimension sizes of the input tensor, given the dimension axes
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
Operation *op, Value value,
ArrayRef<int64_t> inpDims);
ArrayRef<int64_t> inpDims,
size_t dimSizeIndexBits);
// Get the dimension sizes of the input tensor
FailureOr<SmallVector<Value, 4>>
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value);
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
Operation *op, Value value,
size_t dimSizeIndexBits);
// Get a tensor that unsqueezed the specified dimensions of the input tensor
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value tensor,
ArrayRef<int64_t> inputUnsqzDims);
Value tensor, ArrayRef<int64_t> inputUnsqzDims,
size_t dimSizeIndexBits);
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType);
} // namespace mhlo
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H

View File

@ -28,6 +28,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
@ -72,22 +73,9 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
return nullptr;
}
namespace {
// Generic conversion pattern for Aten pooling ops. matchAndRewrite is only
// declared here; each supported pooling op provides an explicit
// specialization elsewhere in this file.
template <typename AtenOpT>
class ConvertAtenPoolingOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  // No generic implementation — specialized per op.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
// AtenMaxPool2dOp
namespace {
template <>
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dOp>::matchAndRewrite(
LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
AtenMaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.self();
@ -186,12 +174,10 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dOp>::matchAndRewrite(
rewriter.replaceOp(op, reduceWindowOp.getResults());
return success();
}
} // namespace
// AtenMaxPool2dWithIndicesOp
namespace {
template <>
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.self();
@ -269,7 +255,9 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
rewriter.getI64Type()),
mhloPadding);
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -379,12 +367,10 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
rewriter.replaceOp(op, reduceWindowOp.getResults());
return success();
}
} // namespace
// AtenAvgPool2dOp
namespace {
template <>
LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
AtenAvgPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.self();
@ -502,7 +488,9 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
Value windowSizeConst =
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input);
const auto &options = getOptions();
auto inputShapeVec =
*mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);
@ -540,17 +528,15 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
return success();
}
} // namespace
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dOp>>(typeConverter, context);
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
target.addIllegalOp<AtenAvgPool2dOp>();
patterns.add<ConvertAtenPoolingOp<AtenAvgPool2dOp>>(typeConverter, context);
patterns.add<ConvertAtenOp<AtenAvgPool2dOp>>(typeConverter, context, options);
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
context);
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
context, options);
}

View File

@ -16,25 +16,56 @@ namespace mlir {
namespace torch {
namespace torch_to_mhlo {
// Configuration shared by all Torch -> MHLO conversion patterns.
struct TorchToMhloOptions {
  // Assume statically-shaped tensors during conversion.
  bool enableStaticShape = false;
  // Bit width of the integer type used for dimension-size computations.
  // 32 truncates i64 dim sizes to i32 (unsafe if a dim exceeds i32 range)
  // but is faster on devices where i64 arithmetic is slow.
  size_t dimSizeIndexBits = 64;
};
// Base conversion pattern for Aten ops. Carries the TorchToMhlo lowering
// options (static-shape mode, dim-size index width) alongside the usual
// OpConversionPattern state; concrete ops specialize matchAndRewrite.
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
  using OpAdaptor = typename AtenOpT::Adaptor;
  // Initialize `options` in the member-initializer list instead of
  // assigning inside the constructor body (avoids default-construct-then-
  // assign; see cppcoreguidelines-prefer-member-initializer).
  ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
                const TorchToMhloOptions &options)
      : OpConversionPattern<AtenOpT>(typeConverter, context),
        options(options) {}

  // Default implementation rejects the match; each supported op provides
  // an explicit specialization.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(op, "haven't been implemented");
  }

  const TorchToMhloOptions &getOptions() const { return options; }

private:
  TorchToMhloOptions options;
};
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
RewritePatternSet &patterns,
ConversionTarget &target,
const TorchToMhloOptions &options);
void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
ConversionTarget &target,
const TorchToMhloOptions &options);
void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns,
ConversionTarget &target);
ConversionTarget &target,
const TorchToMhloOptions &options);
} // namespace torch_to_mhlo
} // namespace torch

View File

@ -25,6 +25,7 @@
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
@ -72,7 +73,8 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
// Util for converting AtenArgmaxOp and AtenMaxDimOp
static llvm::Optional<ValueRange>
getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
ArrayRef<Value> inputShapeVec, int64_t dim) {
ArrayRef<Value> inputShapeVec, int64_t dim,
size_t dimSizeIndexBits) {
auto inputTy = input.getType().template cast<RankedTensorType>();
if (!inputTy) {
return llvm::None;
@ -86,7 +88,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter);
if (!initValue) return llvm::None;
Value initIndex;
if (mlir::mhlo::kMhloDimSizeBits == 32) {
if (dimSizeIndexBits == 32) {
initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
} else {
initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
@ -98,7 +100,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);
auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
op->getLoc(), RankedTensorType::get(inputShape, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits)),
op->getLoc(),
RankedTensorType::get(inputShape,
rewriter.getIntegerType(dimSizeIndexBits)),
inputShapeTensor, static_cast<uint64_t>(dim));
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
@ -114,7 +118,8 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
// Add block arguments
auto blockValArgumentType =
RankedTensorType::get({}, inputTy.getElementType());
auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits));
auto blockIdxArgumentType =
RankedTensorType::get({}, rewriter.getIntegerType(dimSizeIndexBits));
auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
block.addArgument(blockValArgumentType, op->getLoc());
block.addArgument(blockIdxArgumentType, op->getLoc());
@ -171,9 +176,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
namespace {
template <typename AtenOpT>
class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
class ConvertAtenReductionOp : public ConvertAtenOp<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
@ -220,21 +225,24 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
}
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults =
getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
options.dimSizeIndexBits)
.value();
if (keepDim) {
auto outShapeVec = inputShapeVec;
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr(
rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits), 1));
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
@ -297,20 +305,24 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
}
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults =
getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
options.dimSizeIndexBits)
.value();
if (keepDim) {
auto outShapeVec = inputShapeVec;
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr(
rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1));
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
@ -532,15 +544,18 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
}
if (keepDim) {
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
const auto &options = getOptions();
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input,
options.dimSizeIndexBits);
if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
}
auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr(
rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1));
op->getLoc(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) {
outShapeVec[i] = one;
}
@ -558,11 +573,11 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context);
patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);

View File

@ -32,6 +32,12 @@ namespace {
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
public:
ConvertTorchToMhlo() = default;
// Construct the pass with explicit option values. Assignment (rather than
// a member-initializer list) is used because enableStaticShape and
// enableI32Index are the tablegen-declared pass Option<> members.
ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) {
  this->enableStaticShape = enableStaticShape;
  this->enableI32Index = enableI32Index;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::ChloDialect>();
registry.insert<mhlo::MhloDialect>();
@ -51,18 +57,20 @@ public:
RewritePatternSet patterns(context);
torch_to_mhlo::TorchToMhloOptions options{enableStaticShape,
enableI32Index ? 32u : 64u};
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
target);
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter,
patterns, target);
target, options);
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
target);
torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter,
patterns, target);
target, options);
torch_to_mhlo::populateReductionOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
target);
target, options);
torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
target);
target, options);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
@ -75,5 +83,12 @@ public:
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass() {
return std::make_unique<ConvertTorchToMhlo>();
return std::make_unique<ConvertTorchToMhlo>(false, false);
}
// Factory for the configurable TorchToMhlo pass.
// `enableStaticShape` turns on static-shape conversion; `enableI32Index`
// truncates dimension-size index computations from i64 to i32 (mapped to
// dimSizeIndexBits = 32 inside the pass; unsafe if a dim exceeds i32 range).
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape,
                                          bool enableI32Index) {
  return std::make_unique<ConvertTorchToMhlo>(enableStaticShape,
                                              enableI32Index);
}

View File

@ -28,6 +28,7 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_to_mhlo;
namespace {
// A dimension index from torch.dialect might outside the range [0, dimSize].
@ -55,10 +56,11 @@ Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
Type outTy, Value input, Value startIndex,
Value endIndex, Value step, size_t dimIndex,
ArrayRef<Value> dimSizes) {
ArrayRef<Value> dimSizes,
size_t dimSizeIndexBits) {
auto loc = op->getLoc();
// startIndex & endIndex has been normailized into range [0, dSize]
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 0));
Value one = rewriter.create<arith::ConstantOp>(
@ -109,7 +111,8 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
Type outTy, Value input,
llvm::Optional<Value> startIndexOpt,
llvm::Optional<Value> endIndexOpt,
llvm::Optional<Value> stepOpt, int64_t dim) {
llvm::Optional<Value> stepOpt, int64_t dim,
size_t dimSizeIndexBits) {
auto loc = op->getLoc();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
auto rank = inputTy.getRank();
@ -133,77 +136,31 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
: rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
auto i32Type = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
normStartIndex =
rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex);
normEndIndex = rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
#endif
if (dimSizeIndexBits == 32) {
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
normStartIndex =
rewriter.create<arith::TruncIOp>(loc, intType, normStartIndex);
normEndIndex = rewriter.create<arith::TruncIOp>(loc, intType, normEndIndex);
step = rewriter.create<arith::TruncIOp>(loc, intType, step);
}
FailureOr<SmallVector<Value, 4>> dimSizesInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input);
mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto dimSizes = *dimSizesInfo;
return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex,
normEndIndex, step, dim, dimSizes);
}
// Generic conversion pattern for Aten ops. matchAndRewrite is only
// declared here; each supported op provides an explicit specialization
// below.
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  // No generic implementation — specialized per op.
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
template <>
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
AtenSliceTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.self();
auto selfTy = self.getType().template cast<RankedTensorType>();
if (!selfTy)
return op.emitError("only ranked tensor types are supported");
auto outTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "only constant dim is currently supported");
auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
if (val.getType().isa<Torch::NoneType>()) {
return llvm::None;
} else {
return val;
}
};
llvm::Optional<Value> start = getOptionalVal(adaptor.start());
llvm::Optional<Value> end = getOptionalVal(adaptor.end());
llvm::Optional<Value> step = getOptionalVal(adaptor.step());
FailureOr<Value> sliceInfo =
getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim);
if (failed(sliceInfo))
return op.emitError("can not create a dynmaic slice");
auto slice = *sliceInfo;
rewriter.replaceOp(op, slice);
return success();
normEndIndex, step, dim, dimSizes,
dimSizeIndexBits);
}
// This defines a template to construct ops whose legalizations are
// specialized.
template <typename AtenOpT>
class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
@ -235,19 +192,19 @@ public:
return dSize;
});
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
// unlikely to exceed the range of i32(4GiB)
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
// dimSize: cast i64 -> i32
dSize =
rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
return dSize;
});
#endif
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
if (options.dimSizeIndexBits == 32) {
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
// unlikely to exceed the range of i32(4GiB)
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
// dimSize: cast i64 -> i32
dSize = rewriter.create<arith::TruncIOp>(loc, intType, dSize);
return dSize;
});
}
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
Value numel = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
for (auto d : dimSizes) {
@ -293,6 +250,45 @@ bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
SmallVector<Value, 4> &dimSizes) const {
return getListConstructElements(adaptor.shape(), dimSizes);
}
} // namespace
// AtenSliceTensorOp: lower torch.aten.slice.Tensor to a dynamic MHLO slice.
// Only a compile-time-constant `dim` is supported; start/end/step may be
// dynamic values or None.
template <>
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
    AtenSliceTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.self();
  // dyn_cast (not cast): cast<> asserts on a type mismatch and never
  // returns null, which made the guard below unreachable.
  auto selfTy = self.getType().template dyn_cast<RankedTensorType>();
  if (!selfTy)
    return op.emitError("only ranked tensor types are supported");
  auto outTy =
      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
  int64_t dim;
  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
    return rewriter.notifyMatchFailure(
        op, "only constant dim is currently supported");

  // Map torch.none operands to llvm::None so getDynamicSlice applies its
  // defaults for the missing bound/step.
  auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
    if (val.getType().isa<Torch::NoneType>()) {
      return llvm::None;
    } else {
      return val;
    }
  };

  llvm::Optional<Value> start = getOptionalVal(adaptor.start());
  llvm::Optional<Value> end = getOptionalVal(adaptor.end());
  llvm::Optional<Value> step = getOptionalVal(adaptor.step());

  FailureOr<Value> sliceInfo =
      getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim,
                      options.dimSizeIndexBits);
  if (failed(sliceInfo))
    return op.emitError("can not create a dynamic slice");

  rewriter.replaceOp(op, *sliceInfo);
  return success();
}
template <>
LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
@ -324,7 +320,8 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
return success();
}
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -372,7 +369,8 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
@ -397,8 +395,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return op->emitError("dim must be a Scalar constant");
auto unsqzTensorInfo =
mhlo::unsqueezeTensor(rewriter, op, adaptor.self(), {dim});
auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.self(),
{dim}, options.dimSizeIndexBits);
if (failed(unsqzTensorInfo))
return rewriter.notifyMatchFailure(op,
"failed to create unsqueezed tensor");
@ -406,16 +404,15 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
rewriter.replaceOp(op, *unsqzTensorInfo);
return success();
}
} // namespace
void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenSqueezeOp);
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
@ -424,7 +421,7 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
INSERT_VIEW_OP_PATTERN(AtenViewOp);
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
#undef INSERT_VIEW_OP_PATTERN