[MHLO] refactor pass configurations (#1315)

Related to https://github.com/llvm/torch-mlir/issues/1227

1. Reduce MHLO #ifdefs
2. Dismiss compilation warnings
pull/1321/head
Tanyo Kwok 2022-09-01 10:36:02 +08:00 committed by GitHub
parent 7769eb88f8
commit 29cafdbb61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 373 additions and 380 deletions

View File

@ -39,14 +39,6 @@ endmacro()
option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON) option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
if(TORCH_MLIR_ENABLE_MHLO) if(TORCH_MLIR_ENABLE_MHLO)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO) add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
# The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
# One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
# the range of i32(4GiB)
option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
"Enable truncate dimension size from i64 to i32(unsafely)" OFF)
if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
endif()
endif() endif()
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)

View File

@ -132,6 +132,17 @@ def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
Convert Torch ops to mhlo ops. Convert Torch ops to mhlo ops.
}]; }];
let constructor = "mlir::torch::createConvertTorchToMhloPass()"; let constructor = "mlir::torch::createConvertTorchToMhloPass()";
// Specify any options.
let options = [
Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false",
"Enable static shape conversion">,
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
// are unlikely to exceed the range of i32(4GiB)
Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
"Enable truncate index from i64 to i32(unsafely)">,
];
} }
#endif #endif

View File

@ -17,6 +17,8 @@
namespace mlir { namespace mlir {
namespace torch { namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass(); std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
} // namespace torch } // namespace torch
} // namespace mlir } // namespace mlir

View File

@ -29,6 +29,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
bool skipMultiplyAlpha(Value alphaValue) { bool skipMultiplyAlpha(Value alphaValue) {
double doubleValue; double doubleValue;
@ -379,63 +380,58 @@ public:
} // namespace } // namespace
// AtenBroadcastToOp // AtenBroadcastToOp
namespace { template <>
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> { LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
public: AtenBroadcastToOp op, OpAdaptor adaptor,
using OpConversionPattern::OpConversionPattern; ConversionPatternRewriter &rewriter) const {
LogicalResult Value self = adaptor.self();
matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor, auto selfTy = self.getType().cast<RankedTensorType>();
ConversionPatternRewriter &rewriter) const override { auto outType = getTypeConverter()
Value self = adaptor.self(); ->convertType(op->getResult(0).getType())
auto selfTy = self.getType().cast<RankedTensorType>(); .cast<RankedTensorType>();
auto outType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
#ifdef TORCH_MLIR_ENABLE_MHLO_STATIC_SHAPE if (options.enableStaticShape && selfTy.hasStaticShape()) {
if (selfTy.hasStaticShape()) { Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType); rewriter.replaceOp(op, bcastOp);
rewriter.replaceOp(op, bcastOp); return success();
return success(); }
SmallVector<Value> shape;
if (!(getListConstructElements(adaptor.size(), shape))) {
return op->emitError("desired shape must be a list of scalar");
}
SmallVector<Value> bcastShapeVec;
int64_t totalRank = shape.size();
int64_t selfRank = selfTy.getRank();
int64_t leadingRank = totalRank - selfRank;
for (int64_t i = 0; i < totalRank; ++i) {
Value dValue = shape[i];
Value newD;
int64_t dInt;
if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) {
return op->emitError("element of desired shape must be a scalar");
} }
#endif if (i >= leadingRank && dInt == -1) {
newD = rewriter.create<mlir::tensor::DimOp>(op->getLoc(), self,
SmallVector<Value> shape; i - leadingRank);
if (!(getListConstructElements(adaptor.size(), shape))) { } else {
return op->emitError("desired shape must be a list of scalar"); dValue = rewriter.create<torch::TorchConversion::ToI64Op>(op->getLoc(),
dValue);
newD = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getIndexType(), dValue);
} }
SmallVector<Value> bcastShapeVec; bcastShapeVec.push_back(newD);
int64_t totalRank = shape.size(); }
int64_t selfRank = selfTy.getRank();
int64_t leadingRank = totalRank - selfRank;
for (int64_t i = 0; i < totalRank; ++i) { if (options.dimSizeIndexBits == 32) {
Value dValue = shape[i];
Value newD;
int64_t dInt;
if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) {
return op->emitError("element of desired shape must be a scalar");
}
if (i >= leadingRank && dInt == -1) {
newD = rewriter.create<mlir::tensor::DimOp>(op->getLoc(), self,
i - leadingRank);
} else {
dValue = rewriter.create<torch::TorchConversion::ToI64Op>(op->getLoc(),
dValue);
newD = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getIndexType(), dValue);
}
bcastShapeVec.push_back(newD);
}
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
for (auto &dsize : bcastShapeVec) { for (auto &dsize : bcastShapeVec) {
auto dsizeI64 = rewriter.create<mlir::arith::IndexCastOp>( auto dsizeI64 = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getI64Type(), dsize); op->getLoc(), rewriter.getI64Type(), dsize);
dsize = rewriter.create<arith::TruncIOp>(op->getLoc(), dsize = rewriter.create<arith::TruncIOp>(op->getLoc(),
rewriter.getI32Type(), dsizeI64); rewriter.getI32Type(), dsizeI64);
} }
#endif }
Value bcastShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( Value bcastShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), ValueRange{bcastShapeVec}); op->getLoc(), ValueRange{bcastShapeVec});
@ -445,66 +441,45 @@ public:
op, outType, self, bcastShapeTensor, op, outType, self, bcastShapeTensor,
rewriter.getI64TensorAttr(dimensionNumbers)); rewriter.getI64TensorAttr(dimensionNumbers));
return success(); return success();
} }
};
} // namespace
// AtenPermuteOp // AtenPermuteOp
namespace { template <>
class ConvertAtenPermuteOp : public OpConversionPattern<AtenPermuteOp> { LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
public: AtenPermuteOp op, OpAdaptor adaptor,
using OpConversionPattern::OpConversionPattern; ConversionPatternRewriter &rewriter) const {
LogicalResult Value self = adaptor.self();
matchAndRewrite(AtenPermuteOp op, OpAdaptor adaptor, // Not a ranked tensor type
ConversionPatternRewriter &rewriter) const override { auto inType = self.getType().dyn_cast<RankedTensorType>();
Value self = adaptor.self(); auto outType = getTypeConverter()
// Not a ranked tensor type ->convertType(op->getResult(0).getType())
auto inType = self.getType().dyn_cast<RankedTensorType>(); .cast<RankedTensorType>();
auto outType = getTypeConverter() if (!inType)
->convertType(op->getResult(0).getType()) return op.emitError("only ranked tensor types with static shapes are "
.cast<RankedTensorType>(); "currently supported");
if (!inType)
return op.emitError("only ranked tensor types with static shapes are "
"currently supported");
SmallVector<int64_t> permValues; SmallVector<int64_t> permValues;
if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues))) if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only constant dimensions are currently supported"); op, "only constant dimensions are currently supported");
int64_t inRank = inType.getRank(); int64_t inRank = inType.getRank();
for (auto &d : permValues) { for (auto &d : permValues) {
d = toPositiveDim(d, inRank); d = toPositiveDim(d, inRank);
if (!isValidDim(d, inRank)) if (!isValidDim(d, inRank))
return op.emitError("not all dims are valid"); return op.emitError("not all dims are valid");
}
DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(permValues.size())},
rewriter.getI64Type()),
permValues);
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
permutation);
return success();
} }
};
} // namespace DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(permValues.size())},
namespace { rewriter.getI64Type()),
template <typename AtenOpT> permValues);
class ConvertAtenOp : public OpConversionPattern<AtenOpT> { rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
public: permutation);
using OpConversionPattern<AtenOpT>::OpConversionPattern; return success();
using OpAdaptor = typename AtenOpT::Adaptor; }
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
// AtenTanhOp // AtenTanhOp
namespace {
template <> template <>
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
AtenTanhOp op, OpAdaptor adaptor, AtenTanhOp op, OpAdaptor adaptor,
@ -520,10 +495,8 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
"only floating-point datatype legalization currently supported"); "only floating-point datatype legalization currently supported");
} }
} }
} // namespace
// ValueTensorLiteralOp // ValueTensorLiteralOp
namespace {
template <> template <>
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite( LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
ValueTensorLiteralOp op, OpAdaptor adaptor, ValueTensorLiteralOp op, OpAdaptor adaptor,
@ -553,11 +526,9 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
return success(); return success();
} }
} // namespace
// AtenReciprocalOp // AtenReciprocalOp
// Reciprocal(x) = Div(1, x) // Reciprocal(x) = Div(1, x)
namespace {
template <> template <>
LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
AtenReciprocalOp op, OpAdaptor adaptor, AtenReciprocalOp op, OpAdaptor adaptor,
@ -575,10 +546,8 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input); rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
return success(); return success();
} }
} // namespace
// PrimNumToTensorScalarOp // PrimNumToTensorScalarOp
namespace {
template <> template <>
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite( LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
PrimNumToTensorScalarOp op, OpAdaptor adaptor, PrimNumToTensorScalarOp op, OpAdaptor adaptor,
@ -592,11 +561,9 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
rewriter.replaceOp(op, mhloTensor); rewriter.replaceOp(op, mhloTensor);
return success(); return success();
} }
} // namespace
// AtenContiguousOp // AtenContiguousOp
// Ref: TosaToTosa.cpp for implementation details // Ref: TosaToTosa.cpp for implementation details
namespace {
template <> template <>
LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
AtenContiguousOp op, OpAdaptor adaptor, AtenContiguousOp op, OpAdaptor adaptor,
@ -614,11 +581,9 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
return success(); return success();
} }
} // namespace
// AtenReluOp // AtenReluOp
// Relu(x) = Max(0, x) // Relu(x) = Max(0, x)
namespace {
template <> template <>
LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
AtenReluOp op, OpAdaptor adaptor, AtenReluOp op, OpAdaptor adaptor,
@ -641,11 +606,9 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
return success(); return success();
} }
} // namespace
// Convert a Aten::GELU to HLO // Convert a Aten::GELU to HLO
// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))] // Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
namespace {
template <> template <>
LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
AtenGeluOp op, OpAdaptor adaptor, AtenGeluOp op, OpAdaptor adaptor,
@ -668,10 +631,8 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul); rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
return success(); return success();
} }
} // namespace
// AtenErfOp // AtenErfOp
namespace {
template <> template <>
LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
AtenErfOp op, OpAdaptor adaptor, AtenErfOp op, OpAdaptor adaptor,
@ -686,10 +647,8 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
return success(); return success();
} }
} // namespace
// AtenBatchNormOp // AtenBatchNormOp
namespace {
template <> template <>
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
AtenBatchNormOp op, OpAdaptor adaptor, AtenBatchNormOp op, OpAdaptor adaptor,
@ -716,12 +675,12 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1); Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1);
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 if (options.dimSizeIndexBits == 32) {
auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>( auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getI64Type(), channelDim); op->getLoc(), rewriter.getI64Type(), channelDim);
channelDim = rewriter.create<arith::TruncIOp>( channelDim = rewriter.create<arith::TruncIOp>(
op->getLoc(), rewriter.getI32Type(), channelDimI64); op->getLoc(), rewriter.getI32Type(), channelDimI64);
#endif }
Value channelShape = rewriter.create<tensor::FromElementsOp>( Value channelShape = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ValueRange{channelDim}); op->getLoc(), ValueRange{channelDim});
@ -806,10 +765,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
} }
} }
} // namespace
// AtenNativeLayerNormOp // AtenNativeLayerNormOp
namespace {
template <> template <>
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
AtenNativeLayerNormOp op, OpAdaptor adaptor, AtenNativeLayerNormOp op, OpAdaptor adaptor,
@ -949,10 +906,8 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
return success(); return success();
} }
} // namespace
// AtenCatOp // AtenCatOp
namespace {
template <> template <>
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
AtenCatOp op, OpAdaptor adaptor, AtenCatOp op, OpAdaptor adaptor,
@ -983,10 +938,8 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
op, outType, ValueRange(builtinTensors), posDim); op, outType, ValueRange(builtinTensors), posDim);
return success(); return success();
} }
} // namespace
// AtenNumelOp // AtenNumelOp
namespace {
template <> template <>
LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
AtenNumelOp op, AtenNumelOp op,
@ -996,7 +949,7 @@ LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
auto selfTy = self.getType().dyn_cast<RankedTensorType>(); auto selfTy = self.getType().dyn_cast<RankedTensorType>();
size_t rank = selfTy.getRank(); size_t rank = selfTy.getRank();
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
auto loc = op->getLoc(); auto loc = op->getLoc();
Value numel = Value numel =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(intType, 1)); rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(intType, 1));
@ -1015,26 +968,18 @@ LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
} }
return success(); return success();
} }
} // namespace
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenTransposeIntOp>(); target.addIllegalOp<AtenTransposeIntOp>();
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context); patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
target.addIllegalOp<AtenBroadcastToOp>();
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
target.addIllegalOp<AtenPermuteOp>();
patterns.add<ConvertAtenPermuteOp>(typeConverter, context);
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \ #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, \ patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, context)
context);
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp); INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp);
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp); INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp); INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp);
@ -1045,14 +990,14 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \ patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
context); context)
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN #undef INSERT_CONSTANT_FILL_PATTERN
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \ #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context); patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context)
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp); INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp); INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp); INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
@ -1062,7 +1007,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \ #define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context); patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context)
INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp); INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp); INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
@ -1072,7 +1017,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \ #define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context); patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context)
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
@ -1086,7 +1031,11 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
#define INSERT_ATENOP_PATTERN(AtenOp) \ #define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context); patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenTanhOp); INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenReciprocalOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp);

View File

@ -23,18 +23,14 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
static constexpr size_t kMhloDimSizeBits = 32;
#else
static constexpr size_t kMhloDimSizeBits = 64;
#endif
namespace { namespace {
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
Value input, Value indices, int64_t axis) { Value input, Value indices, int64_t axis,
size_t dimSizeIndexBits) {
auto loc = op->getLoc(); auto loc = op->getLoc();
Type intType = rewriter.getIntegerType(kMhloDimSizeBits); Type intType = rewriter.getIntegerType(dimSizeIndexBits);
Value one = rewriter.create<arith::ConstantOp>( Value one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1)); loc, rewriter.getIntegerAttr(intType, 1));
@ -98,16 +94,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
sliceSizesTensor, dimsAttr) sliceSizesTensor, dimsAttr)
.getResult(); .getResult();
} }
} // namespace
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html // Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// padding_idx (int, optional) // padding_idx (int, optional)
@ -149,8 +136,8 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "sparse gradients is currently not supported"); op, "sparse gradients is currently not supported");
Value output = Value output = gatherTensorAlongSingleAxis(
gatherTensorAlongSingleAxis(rewriter, op, weight, adaptor.indices(), 0); rewriter, op, weight, adaptor.indices(), 0, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>( rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output); op, getTypeConverter()->convertType(op.getType()), output);
@ -170,24 +157,23 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only constant dim is currently supported"); op, "only constant dim is currently supported");
Value output = Value output = gatherTensorAlongSingleAxis(
gatherTensorAlongSingleAxis(rewriter, op, self, adaptor.index(), dim); rewriter, op, self, adaptor.index(), dim, options.dimSizeIndexBits);
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>( rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op, getTypeConverter()->convertType(op.getType()), output); op, getTypeConverter()->convertType(op.getType()), output);
return success(); return success();
} }
} // namespace
void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality( void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \ #define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context); patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
INSERT_ATENOP_PATTERN(AtenEmbeddingOp); INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp); INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN

View File

@ -25,6 +25,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
namespace { namespace {
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
@ -71,7 +72,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
} }
void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
Value &inpRhs, int64_t leadingRank) { Value &inpRhs, int64_t leadingRank,
size_t dimSizeIndexBits) {
Value lhs = inpLhs; Value lhs = inpLhs;
Value rhs = inpRhs; Value rhs = inpRhs;
auto lhsRankTy = inpLhs.getType().dyn_cast<RankedTensorType>(); auto lhsRankTy = inpLhs.getType().dyn_cast<RankedTensorType>();
@ -92,9 +94,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(rhsShape.begin(), std::vector<int64_t> newShape(rhsShape.begin(),
rhsShape.begin() + leadingRank); rhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
auto newDimSizes = auto newDimSizes = *mhlo::getDimSizesOfTensor(
*mhlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims); rewriter, op, rhs, leadingDims, dimSizeIndexBits);
auto lhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, lhs); auto lhsDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
lhsDimSizes.end()); lhsDimSizes.end());
lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
@ -103,9 +106,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
std::vector<int64_t> newShape(lhsShape.begin(), std::vector<int64_t> newShape(lhsShape.begin(),
lhsShape.begin() + leadingRank); lhsShape.begin() + leadingRank);
newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
auto newDimSizes = auto newDimSizes = *mhlo::getDimSizesOfTensor(
*mhlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims); rewriter, op, lhs, leadingDims, dimSizeIndexBits);
auto rhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, rhs); auto rhsDimSizes =
*mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(), newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
rhsDimSizes.end()); rhsDimSizes.end());
rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
@ -122,9 +126,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
// implement their specialized input processing (e.g transpose), and output // implement their specialized input processing (e.g transpose), and output
// processing, e.g. GEMM or fully connected bias handling. // processing, e.g. GEMM or fully connected bias handling.
template <typename AtenOpT> template <typename AtenOpT>
class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> { class ConvertAtenMatmulBaseOp : public ConvertAtenOp<AtenOpT> {
public: public:
using OpConversionPattern<AtenOpT>::OpConversionPattern; using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor; using OpAdaptor = typename AtenOpT::Adaptor;
// Each variant must implement corresponding parameter parsing options. // Each variant must implement corresponding parameter parsing options.
// Maintain separate input read functions for each variant because it is not // Maintain separate input read functions for each variant because it is not
@ -159,20 +163,24 @@ public:
return success(); return success();
} }
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
int64_t nBatchDims; int64_t nBatchDims;
if (rhsRank <= 2) { if (rhsRank <= 2) {
auto leadingRank = lhsRank - 2; auto leadingRank = lhsRank - 2;
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank); getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
options.dimSizeIndexBits);
nBatchDims = leadingRank; nBatchDims = leadingRank;
} else if (lhsRank <= 2) { } else if (lhsRank <= 2) {
auto leadingRank = rhsRank - 2; auto leadingRank = rhsRank - 2;
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank); getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
options.dimSizeIndexBits);
nBatchDims = leadingRank; nBatchDims = leadingRank;
} else { } else {
assert(rhsRank > 2 && lhsRank > 2); assert(rhsRank > 2 && lhsRank > 2);
auto leadingRank = std::max(lhsRank - rhsRank, rhsRank - lhsRank); auto leadingRank = std::max(lhsRank - rhsRank, rhsRank - lhsRank);
nBatchDims = std::max(lhsRank - 2, rhsRank - 2); nBatchDims = std::max(lhsRank - 2, rhsRank - 2);
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank); getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
options.dimSizeIndexBits);
} }
auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims)); auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
auto lhsContractingDim = nBatchDims + 1; auto lhsContractingDim = nBatchDims + 1;
@ -187,7 +195,7 @@ public:
/*rhsBatchingDimensions=*/batchDims, /*rhsBatchingDimensions=*/batchDims,
/*lhsContractingDimensions=*/{lhsContractingDim}, /*lhsContractingDimensions=*/{lhsContractingDim},
/*rhsContractingDimensions=*/{rhsContractingDim}); /*rhsContractingDimensions=*/{rhsContractingDim});
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter() auto resultTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template cast<RankedTensorType>(); .template cast<RankedTensorType>();
@ -215,7 +223,7 @@ public:
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>( rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
op, op,
OpConversionPattern<AtenOpT>::getTypeConverter() ConvertAtenOp<AtenOpT>::getTypeConverter()
->convertType(op.getType()) ->convertType(op.getType())
.template cast<RankedTensorType>(), .template cast<RankedTensorType>(),
output); output);
@ -340,7 +348,10 @@ public:
auto rhsTy = rhs.getType().cast<RankedTensorType>(); auto rhsTy = rhs.getType().cast<RankedTensorType>();
auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(), auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
rhsTy.getRank() - lhsTy.getRank()); rhsTy.getRank() - lhsTy.getRank());
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
options.dimSizeIndexBits);
auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank()); auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
auto nBatchDims = resultRank - 2; auto nBatchDims = resultRank - 2;
auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims)); auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
@ -356,8 +367,7 @@ public:
/*rhsContractingDimensions=*/{rhsContractingDim}); /*rhsContractingDimensions=*/{rhsContractingDim});
auto resultTy = auto resultTy =
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType( ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
op.getType());
Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>( Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr); op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr);
@ -377,12 +387,9 @@ public:
} }
}; };
} // namespace class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
namespace {
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
public: public:
using OpConversionPattern<AtenConvolutionOp>::OpConversionPattern; using ConvertAtenOp<AtenConvolutionOp>::ConvertAtenOp;
using OpAdaptor = typename AtenConvolutionOp::Adaptor; using OpAdaptor = typename AtenConvolutionOp::Adaptor;
Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op, Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op,
@ -390,8 +397,9 @@ public:
auto weightTy = weight.getType().cast<RankedTensorType>(); auto weightTy = weight.getType().cast<RankedTensorType>();
auto weightElemTy = weightTy.getElementType(); auto weightElemTy = weightTy.getElementType();
auto rank = weightTy.getRank(); auto rank = weightTy.getRank();
SmallVector<Value> weightShapeVec = const auto &options = getOptions();
*mhlo::getDimSizesOfTensor(rewriter, op, weight); SmallVector<Value> weightShapeVec = *mhlo::getDimSizesOfTensor(
rewriter, op, weight, options.dimSizeIndexBits);
auto weightShape = weightTy.getShape(); auto weightShape = weightTy.getShape();
SmallVector<int64_t> weightShapeInt(rank); SmallVector<int64_t> weightShapeInt(rank);
std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin()); std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin());
@ -601,9 +609,8 @@ public:
return mhloConvOp.getResult(); return mhloConvOp.getResult();
} }
LogicalResult LogicalResult matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.input(); Value input = adaptor.input();
Value weight = adaptor.weight(); Value weight = adaptor.weight();
@ -714,7 +721,10 @@ public:
// Reshape and promote bias // Reshape and promote bias
auto inputUnsqzDims = auto inputUnsqzDims =
llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0)); llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
const auto &options = getOptions();
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
options.dimSizeIndexBits);
bias = mhlo::promoteType(rewriter, bias, outTy); bias = mhlo::promoteType(rewriter, bias, outTy);
DenseIntElementsAttr bcastDimensions; DenseIntElementsAttr bcastDimensions;
@ -727,31 +737,31 @@ public:
void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality( void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context); patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context, options)
INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
#undef INSERT_MATMUL_ATEMOP_PATTERN #undef INSERT_MATMUL_ATEMOP_PATTERN
#define INSERT_MM_ATENOP_PATTERN(AtenOp) \ #define INSERT_MM_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context); patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context, options)
INSERT_MM_ATENOP_PATTERN(AtenMmOp); INSERT_MM_ATENOP_PATTERN(AtenMmOp);
INSERT_MM_ATENOP_PATTERN(AtenBmmOp); INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
#undef INSERT_MM_ATEMOP_PATTERN #undef INSERT_MM_ATEMOP_PATTERN
#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ #define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context); patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context, options)
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATEMOP_PATTERN #undef INSERT_LINEAR_ATEMOP_PATTERN
#define INSERT_CONVOLUTION_ATENOP_PATTERN(AtenOp) \ #define INSERT_CONVOLUTION_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context); patterns.add<ConvertAtenConvolutionOp>(typeConverter, context, options)
INSERT_CONVOLUTION_ATENOP_PATTERN(AtenConvolutionOp); INSERT_CONVOLUTION_ATENOP_PATTERN(AtenConvolutionOp);
#undef INSERT_CONVOLUTION_ATENOP_PATTERN #undef INSERT_CONVOLUTION_ATENOP_PATTERN
} }

View File

@ -259,9 +259,10 @@ SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
return posDims; return posDims;
} }
FailureOr<SmallVector<Value, 4>> FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value, Operation *op, Value value,
ArrayRef<int64_t> inpDims) { ArrayRef<int64_t> inpDims,
size_t dimSizeIndexBits) {
auto valueTy = value.getType().dyn_cast<RankedTensorType>(); auto valueTy = value.getType().dyn_cast<RankedTensorType>();
if (!valueTy) { if (!valueTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -276,14 +277,15 @@ getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
auto loc = op->getLoc(); auto loc = op->getLoc();
for (auto d : dims) { for (auto d : dims) {
dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>( dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIntegerType(kMhloDimSizeBits), loc, rewriter.getIntegerType(dimSizeIndexBits),
rewriter.create<tensor::DimOp>(loc, value, d))); rewriter.create<tensor::DimOp>(loc, value, d)));
} }
return dimSizes; return dimSizes;
} }
FailureOr<SmallVector<Value, 4>> FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) { Operation *op, Value value,
size_t dimSizeIndexBits) {
auto valueTy = value.getType().dyn_cast<RankedTensorType>(); auto valueTy = value.getType().dyn_cast<RankedTensorType>();
if (!valueTy) { if (!valueTy) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
@ -294,12 +296,12 @@ getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
// Get int vector [0, 1, ..., rank-1] // Get int vector [0, 1, ..., rank-1]
std::vector<int64_t> dims(rank); std::vector<int64_t> dims(rank);
std::iota(dims.begin(), dims.end(), 0); std::iota(dims.begin(), dims.end(), 0);
return getDimSizesOfTensor(rewriter, op, value, dims); return getDimSizesOfTensor(rewriter, op, value, dims, dimSizeIndexBits);
} }
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op, FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, Value tensor, ArrayRef<int64_t> inputUnsqzDims,
ArrayRef<int64_t> inputUnsqzDims) { size_t dimSizeIndexBits) {
// Returns a new tensor with dims of size 1 inserted at the specified // Returns a new tensor with dims of size 1 inserted at the specified
// position. // position.
// //
@ -307,7 +309,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
// tensor) are specified with unsqzDims. Indices must be in-order, and in // tensor) are specified with unsqzDims. Indices must be in-order, and in
// range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1, // range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1,
// 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not. // 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not.
auto dimSizesInfo = getDimSizesOfTensor(rewriter, op, tensor); auto dimSizesInfo =
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
if (failed(dimSizesInfo)) if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
@ -324,7 +327,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
auto loc = op->getLoc(); auto loc = op->getLoc();
auto rankTy = tensor.getType().dyn_cast<RankedTensorType>(); auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
auto oldShape = rankTy.getShape(); auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(kMhloDimSizeBits); Type intType = rewriter.getIntegerType(dimSizeIndexBits);
auto one = rewriter.create<arith::ConstantOp>( auto one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1)); loc, rewriter.getIntegerAttr(intType, 1));

View File

@ -19,11 +19,6 @@
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
static constexpr size_t kMhloDimSizeBits = 32;
#else
static constexpr size_t kMhloDimSizeBits = 64;
#endif
using mlir::ConversionPatternRewriter; using mlir::ConversionPatternRewriter;
@ -60,22 +55,23 @@ SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank);
// Get the dimension sizes of the input tensor, given the dimension axes // Get the dimension sizes of the input tensor, given the dimension axes
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter, FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
Operation *op, Value value, Operation *op, Value value,
ArrayRef<int64_t> inpDims); ArrayRef<int64_t> inpDims,
size_t dimSizeIndexBits);
// Get the dimension sizes of the input tensor // Get the dimension sizes of the input tensor
FailureOr<SmallVector<Value, 4>> FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value); Operation *op, Value value,
size_t dimSizeIndexBits);
// Get a tensor that unsqueezed the specified dimensions of the input tensor // Get a tensor that unsqueezed the specified dimensions of the input tensor
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op, FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, Value tensor, ArrayRef<int64_t> inputUnsqzDims,
ArrayRef<int64_t> inputUnsqzDims); size_t dimSizeIndexBits);
Value getConstantOfShape(PatternRewriter &rewriter, Location loc, Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape, const APFloat &constant, Value shape,
TensorType outType); TensorType outType);
} // namespace mhlo } // namespace mhlo
} // namespace mlir } // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H #endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H

View File

@ -28,6 +28,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
@ -72,22 +73,9 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
return nullptr; return nullptr;
} }
namespace {
template <typename AtenOpT>
class ConvertAtenPoolingOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
// AtenMaxPool2dOp // AtenMaxPool2dOp
namespace {
template <> template <>
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
AtenMaxPool2dOp op, OpAdaptor adaptor, AtenMaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.self(); Value input = adaptor.self();
@ -186,12 +174,10 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dOp>::matchAndRewrite(
rewriter.replaceOp(op, reduceWindowOp.getResults()); rewriter.replaceOp(op, reduceWindowOp.getResults());
return success(); return success();
} }
} // namespace
// AtenMaxPool2dWithIndicesOp // AtenMaxPool2dWithIndicesOp
namespace {
template <> template <>
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor, AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.self(); Value input = adaptor.self();
@ -269,7 +255,9 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
rewriter.getI64Type()), rewriter.getI64Type()),
mhloPadding); mhloPadding);
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) { if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
@ -379,12 +367,10 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
rewriter.replaceOp(op, reduceWindowOp.getResults()); rewriter.replaceOp(op, reduceWindowOp.getResults());
return success(); return success();
} }
} // namespace
// AtenAvgPool2dOp // AtenAvgPool2dOp
namespace {
template <> template <>
LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
AtenAvgPool2dOp op, OpAdaptor adaptor, AtenAvgPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { ConversionPatternRewriter &rewriter) const {
Value input = adaptor.self(); Value input = adaptor.self();
@ -502,7 +488,9 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
Value windowSizeConst = Value windowSizeConst =
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value(); mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy); windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input); const auto &options = getOptions();
auto inputShapeVec =
*mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec); op->getLoc(), inputShapeVec);
@ -540,17 +528,15 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
return success(); return success();
} }
} // namespace
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality( void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenMaxPool2dOp>(); target.addIllegalOp<AtenMaxPool2dOp>();
patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dOp>>(typeConverter, context); patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
target.addIllegalOp<AtenAvgPool2dOp>(); target.addIllegalOp<AtenAvgPool2dOp>();
patterns.add<ConvertAtenPoolingOp<AtenAvgPool2dOp>>(typeConverter, context); patterns.add<ConvertAtenOp<AtenAvgPool2dOp>>(typeConverter, context, options);
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>(); target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>>(typeConverter, patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
context); context, options);
} }

View File

@ -16,25 +16,56 @@ namespace mlir {
namespace torch { namespace torch {
namespace torch_to_mhlo { namespace torch_to_mhlo {
struct TorchToMhloOptions {
bool enableStaticShape = false;
size_t dimSizeIndexBits = 64;
};
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
using OpAdaptor = typename AtenOpT::Adaptor;
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
const TorchToMhloOptions &options)
: OpConversionPattern<AtenOpT>(typeConverter, context) {
this->options = options;
}
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return rewriter.notifyMatchFailure(op, "haven't been implemented");
}
const TorchToMhloOptions &getOptions() const { return options; }
private:
TorchToMhloOptions options;
};
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter, void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns, RewritePatternSet &patterns,
ConversionTarget &target); ConversionTarget &target,
const TorchToMhloOptions &options);
void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter, void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns, RewritePatternSet &patterns,
ConversionTarget &target); ConversionTarget &target,
const TorchToMhloOptions &options);
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter, void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns, RewritePatternSet &patterns,
ConversionTarget &target); ConversionTarget &target,
const TorchToMhloOptions &options);
void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter, void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns, RewritePatternSet &patterns,
ConversionTarget &target); ConversionTarget &target,
const TorchToMhloOptions &options);
void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter, void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns, RewritePatternSet &patterns,
ConversionTarget &target); ConversionTarget &target,
const TorchToMhloOptions &options);
void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter, void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
RewritePatternSet &patterns, RewritePatternSet &patterns,
ConversionTarget &target); ConversionTarget &target,
const TorchToMhloOptions &options);
} // namespace torch_to_mhlo } // namespace torch_to_mhlo
} // namespace torch } // namespace torch

View File

@ -25,6 +25,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_mhlo;
static Value createInitialValueForReduceOp(Operation *op, Type elementTy, static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) { PatternRewriter &rewriter) {
@ -72,7 +73,8 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
// Util for converting AtenArgmaxOp and AtenMaxDimOp // Util for converting AtenArgmaxOp and AtenMaxDimOp
static llvm::Optional<ValueRange> static llvm::Optional<ValueRange>
getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
ArrayRef<Value> inputShapeVec, int64_t dim) { ArrayRef<Value> inputShapeVec, int64_t dim,
size_t dimSizeIndexBits) {
auto inputTy = input.getType().template cast<RankedTensorType>(); auto inputTy = input.getType().template cast<RankedTensorType>();
if (!inputTy) { if (!inputTy) {
return llvm::None; return llvm::None;
@ -86,7 +88,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter); Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter);
if (!initValue) return llvm::None; if (!initValue) return llvm::None;
Value initIndex; Value initIndex;
if (mlir::mhlo::kMhloDimSizeBits == 32) { if (dimSizeIndexBits == 32) {
initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value(); initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
} else { } else {
initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value(); initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
@ -98,7 +100,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec); op->getLoc(), inputShapeVec);
auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>( auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
op->getLoc(), RankedTensorType::get(inputShape, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits)), op->getLoc(),
RankedTensorType::get(inputShape,
rewriter.getIntegerType(dimSizeIndexBits)),
inputShapeTensor, static_cast<uint64_t>(dim)); inputShapeTensor, static_cast<uint64_t>(dim));
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>( auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
@ -114,7 +118,8 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
// Add block arguments // Add block arguments
auto blockValArgumentType = auto blockValArgumentType =
RankedTensorType::get({}, inputTy.getElementType()); RankedTensorType::get({}, inputTy.getElementType());
auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits)); auto blockIdxArgumentType =
RankedTensorType::get({}, rewriter.getIntegerType(dimSizeIndexBits));
auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type()); auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
block.addArgument(blockValArgumentType, op->getLoc()); block.addArgument(blockValArgumentType, op->getLoc());
block.addArgument(blockIdxArgumentType, op->getLoc()); block.addArgument(blockIdxArgumentType, op->getLoc());
@ -171,9 +176,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
namespace { namespace {
template <typename AtenOpT> template <typename AtenOpT>
class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> { class ConvertAtenReductionOp : public ConvertAtenOp<AtenOpT> {
public: public:
using OpConversionPattern<AtenOpT>::OpConversionPattern; using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor; using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor, matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
@ -220,21 +225,24 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
} }
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) { if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
} }
auto inputShapeVec = *inputShapeInfo; auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults = auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
getMaxInDim(rewriter, op, input, inputShapeVec, dim).value(); options.dimSizeIndexBits)
.value();
if (keepDim) { if (keepDim) {
auto outShapeVec = inputShapeVec; auto outShapeVec = inputShapeVec;
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>( outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr( op->getLoc(),
rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits), 1)); rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec); op->getLoc(), outShapeVec);
@ -297,20 +305,24 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
} }
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); const auto &options = getOptions();
auto inputShapeInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
if (failed(inputShapeInfo)) { if (failed(inputShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
} }
auto inputShapeVec = *inputShapeInfo; auto inputShapeVec = *inputShapeInfo;
auto mhloReduceResults = auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
getMaxInDim(rewriter, op, input, inputShapeVec, dim).value(); options.dimSizeIndexBits)
.value();
if (keepDim) { if (keepDim) {
auto outShapeVec = inputShapeVec; auto outShapeVec = inputShapeVec;
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>( outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr( op->getLoc(),
rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1)); rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>( auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec); op->getLoc(), outShapeVec);
@ -532,15 +544,18 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
} }
if (keepDim) { if (keepDim) {
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); const auto &options = getOptions();
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input,
options.dimSizeIndexBits);
if (failed(outShapeInfo)) { if (failed(outShapeInfo)) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
} }
auto outShapeVec = *outShapeInfo; auto outShapeVec = *outShapeInfo;
auto one = rewriter.create<mlir::arith::ConstantOp>( auto one = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr( op->getLoc(),
rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1)); rewriter.getIntegerAttr(
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
for (int64_t i : dims) { for (int64_t i : dims) {
outShapeVec[i] = one; outShapeVec[i] = one;
} }
@ -558,11 +573,11 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality( void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \ #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context); patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);

View File

@ -32,6 +32,12 @@ namespace {
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> { class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
public: public:
ConvertTorchToMhlo() = default;
ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) {
this->enableStaticShape = enableStaticShape;
this->enableI32Index = enableI32Index;
}
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::ChloDialect>(); registry.insert<chlo::ChloDialect>();
registry.insert<mhlo::MhloDialect>(); registry.insert<mhlo::MhloDialect>();
@ -51,18 +57,20 @@ public:
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
torch_to_mhlo::TorchToMhloOptions options{enableStaticShape,
enableI32Index ? 32u : 64u};
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns, torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
target); target, options);
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
patterns, target); typeConverter, patterns, target, options);
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns, torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
target); target, options);
torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter, torch_to_mhlo::populateReductionOpPatternsAndLegality(
patterns, target); typeConverter, patterns, target, options);
torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns, torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
target); target, options);
torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns, torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
target); target, options);
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) { std::move(patterns)))) {
@ -75,5 +83,12 @@ public:
/// Factory for the TorchToMhlo conversion pass with both options disabled
/// (dynamic shapes, i64 dimension-size indices) — the historical default.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass() {
  return std::make_unique<ConvertTorchToMhlo>(/*enableStaticShape=*/false,
                                              /*enableI32Index=*/false);
}

/// Factory for the TorchToMhlo conversion pass with caller-selected options.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape,
                                          bool enableI32Index) {
  return std::make_unique<ConvertTorchToMhlo>(enableStaticShape,
                                              enableI32Index);
}

View File

@ -28,6 +28,7 @@ using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion; using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_to_mhlo;
namespace { namespace {
// A dimension index from torch.dialect might outside the range [0, dimSize]. // A dimension index from torch.dialect might outside the range [0, dimSize].
@ -55,10 +56,11 @@ Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
Type outTy, Value input, Value startIndex, Type outTy, Value input, Value startIndex,
Value endIndex, Value step, size_t dimIndex, Value endIndex, Value step, size_t dimIndex,
ArrayRef<Value> dimSizes) { ArrayRef<Value> dimSizes,
size_t dimSizeIndexBits) {
auto loc = op->getLoc(); auto loc = op->getLoc();
// startIndex & endIndex has been normailized into range [0, dSize] // startIndex & endIndex has been normailized into range [0, dSize]
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); Type intType = rewriter.getIntegerType(dimSizeIndexBits);
Value zero = rewriter.create<arith::ConstantOp>( Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 0)); loc, rewriter.getIntegerAttr(intType, 0));
Value one = rewriter.create<arith::ConstantOp>( Value one = rewriter.create<arith::ConstantOp>(
@ -109,7 +111,8 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
Type outTy, Value input, Type outTy, Value input,
llvm::Optional<Value> startIndexOpt, llvm::Optional<Value> startIndexOpt,
llvm::Optional<Value> endIndexOpt, llvm::Optional<Value> endIndexOpt,
llvm::Optional<Value> stepOpt, int64_t dim) { llvm::Optional<Value> stepOpt, int64_t dim,
size_t dimSizeIndexBits) {
auto loc = op->getLoc(); auto loc = op->getLoc();
auto inputTy = input.getType().dyn_cast<RankedTensorType>(); auto inputTy = input.getType().dyn_cast<RankedTensorType>();
auto rank = inputTy.getRank(); auto rank = inputTy.getRank();
@ -133,77 +136,31 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
: rewriter.create<arith::ConstantOp>( : rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1)); loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 if (dimSizeIndexBits == 32) {
auto i32Type = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); Type intType = rewriter.getIntegerType(dimSizeIndexBits);
normStartIndex = normStartIndex =
rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex); rewriter.create<arith::TruncIOp>(loc, intType, normStartIndex);
normEndIndex = rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex); normEndIndex = rewriter.create<arith::TruncIOp>(loc, intType, normEndIndex);
step = rewriter.create<arith::TruncIOp>(loc, i32Type, step); step = rewriter.create<arith::TruncIOp>(loc, intType, step);
#endif }
FailureOr<SmallVector<Value, 4>> dimSizesInfo = FailureOr<SmallVector<Value, 4>> dimSizesInfo =
mhlo::getDimSizesOfTensor(rewriter, op, input); mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
if (failed(dimSizesInfo)) if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
auto dimSizes = *dimSizesInfo; auto dimSizes = *dimSizesInfo;
return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex, return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex,
normEndIndex, step, dim, dimSizes); normEndIndex, step, dim, dimSizes,
} dimSizeIndexBits);
// Generic conversion pattern that lowers a single Torch `AtenOpT` op to the
// MHLO dialect. Each supported op supplies an explicit specialization of
// matchAndRewrite elsewhere in this file.
template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  // Adaptor exposing the already-converted (legalized) operands of the op.
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
template <>
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
AtenSliceTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.self();
auto selfTy = self.getType().template cast<RankedTensorType>();
if (!selfTy)
return op.emitError("only ranked tensor types are supported");
auto outTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "only constant dim is currently supported");
auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
if (val.getType().isa<Torch::NoneType>()) {
return llvm::None;
} else {
return val;
}
};
llvm::Optional<Value> start = getOptionalVal(adaptor.start());
llvm::Optional<Value> end = getOptionalVal(adaptor.end());
llvm::Optional<Value> step = getOptionalVal(adaptor.step());
FailureOr<Value> sliceInfo =
getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim);
if (failed(sliceInfo))
return op.emitError("can not create a dynmaic slice");
auto slice = *sliceInfo;
rewriter.replaceOp(op, slice);
return success();
} }
// This defines a template to construct ops whose legalizations are // This defines a template to construct ops whose legalizations are
// specialized. // specialized.
template <typename AtenOpT> template <typename AtenOpT>
class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> { class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
public: public:
using OpConversionPattern<AtenOpT>::OpConversionPattern; using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
using OpAdaptor = typename AtenOpT::Adaptor; using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult LogicalResult
@ -235,19 +192,19 @@ public:
return dSize; return dSize;
}); });
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
// The i64 calculation is much slower than i32 on some devices, such as Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are if (options.dimSizeIndexBits == 32) {
// unlikely to exceed the range of i32(4GiB) // The i64 calculation is much slower than i32 on some devices, such as
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
// dimSize: cast i64 -> i32 // unlikely to exceed the range of i32(4GiB)
dSize = std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize); // dimSize: cast i64 -> i32
return dSize; dSize = rewriter.create<arith::TruncIOp>(loc, intType, dSize);
}); return dSize;
#endif });
}
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
Value numel = rewriter.create<arith::ConstantOp>( Value numel = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1)); loc, rewriter.getIntegerAttr(intType, 1));
for (auto d : dimSizes) { for (auto d : dimSizes) {
@ -293,6 +250,45 @@ bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
SmallVector<Value, 4> &dimSizes) const { SmallVector<Value, 4> &dimSizes) const {
return getListConstructElements(adaptor.shape(), dimSizes); return getListConstructElements(adaptor.shape(), dimSizes);
} }
} // namespace
// Lower AtenSliceTensorOp to a dynamic slice built by getDynamicSlice.
// Only a compile-time-constant `dim` is supported; start/end/step may be
// dynamic SSA values or absent (Torch None), in which case getDynamicSlice
// substitutes dimension-dependent defaults. Index arithmetic width is taken
// from options.dimSizeIndexBits.
template <>
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
    AtenSliceTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.self();
  // dyn_cast, not cast: cast<> asserts instead of returning null, which made
  // the error path below unreachable.
  auto selfTy = self.getType().template dyn_cast<RankedTensorType>();
  if (!selfTy)
    return op.emitError("only ranked tensor types are supported");
  auto outTy =
      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
  int64_t dim;
  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
    return rewriter.notifyMatchFailure(
        op, "only constant dim is currently supported");

  // Map a Torch None operand to llvm::None; pass any other value through.
  auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
    if (val.getType().isa<Torch::NoneType>()) {
      return llvm::None;
    } else {
      return val;
    }
  };

  llvm::Optional<Value> start = getOptionalVal(adaptor.start());
  llvm::Optional<Value> end = getOptionalVal(adaptor.end());
  llvm::Optional<Value> step = getOptionalVal(adaptor.step());

  FailureOr<Value> sliceInfo =
      getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim,
                      options.dimSizeIndexBits);
  if (failed(sliceInfo))
    return op.emitError("can not create a dynamic slice");

  auto slice = *sliceInfo;
  rewriter.replaceOp(op, slice);
  return success();
}
template <> template <>
LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
@ -324,7 +320,8 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
return success(); return success();
} }
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims); auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
if (failed(newDimSizesInfo)) if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
@ -372,7 +369,8 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
op, getTypeConverter()->convertType(op.getType()), self); op, getTypeConverter()->convertType(op.getType()), self);
return success(); return success();
} }
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims); auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
options.dimSizeIndexBits);
if (failed(newDimSizesInfo)) if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "failed to get dimension sizes of the input");
@ -397,8 +395,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return op->emitError("dim must be a Scalar constant"); return op->emitError("dim must be a Scalar constant");
auto unsqzTensorInfo = auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.self(),
mhlo::unsqueezeTensor(rewriter, op, adaptor.self(), {dim}); {dim}, options.dimSizeIndexBits);
if (failed(unsqzTensorInfo)) if (failed(unsqzTensorInfo))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"failed to create unsqueezed tensor"); "failed to create unsqueezed tensor");
@ -406,16 +404,15 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
rewriter.replaceOp(op, *unsqzTensorInfo); rewriter.replaceOp(op, *unsqzTensorInfo);
return success(); return success();
} }
} // namespace
void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) { ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
#define INSERT_ATENOP_PATTERN(AtenOp) \ #define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context); patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
INSERT_ATENOP_PATTERN(AtenSliceTensorOp); INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenSqueezeOp); INSERT_ATENOP_PATTERN(AtenSqueezeOp);
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp); INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
@ -424,7 +421,7 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
#define INSERT_VIEW_OP_PATTERN(AtenOp) \ #define INSERT_VIEW_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \ target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context); patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
INSERT_VIEW_OP_PATTERN(AtenViewOp); INSERT_VIEW_OP_PATTERN(AtenViewOp);
INSERT_VIEW_OP_PATTERN(AtenReshapeOp); INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
#undef INSERT_VIEW_OP_PATTERN #undef INSERT_VIEW_OP_PATTERN