mirror of https://github.com/llvm/torch-mlir
[MHLO] refactor pass configurations (#1315)
Related to https://github.com/llvm/torch-mlir/issues/1227 1. Reduce MHLO #ifdefs 2. Dismiss compilation warningspull/1321/head
parent
7769eb88f8
commit
29cafdbb61
|
@ -39,14 +39,6 @@ endmacro()
|
|||
option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
|
||||
# The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
|
||||
# One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
|
||||
# the range of i32(4GiB)
|
||||
option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
|
||||
"Enable truncate dimension size from i64 to i32(unsafely)" OFF)
|
||||
if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
|
||||
add_definitions(-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
|
||||
|
|
|
@ -132,6 +132,17 @@ def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
|
|||
Convert Torch ops to mhlo ops.
|
||||
}];
|
||||
let constructor = "mlir::torch::createConvertTorchToMhloPass()";
|
||||
|
||||
// Specify any options.
|
||||
let options = [
|
||||
Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false",
|
||||
"Enable static shape conversion">,
|
||||
// The i64 calculation is much slower than i32 on some devices, such as
|
||||
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
|
||||
// are unlikely to exceed the range of i32(4GiB)
|
||||
Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
|
||||
"Enable truncate index from i64 to i32(unsafely)">,
|
||||
];
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
namespace mlir {
|
||||
namespace torch {
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
|
||||
bool skipMultiplyAlpha(Value alphaValue) {
|
||||
double doubleValue;
|
||||
|
@ -379,26 +380,21 @@ public:
|
|||
} // namespace
|
||||
|
||||
// AtenBroadcastToOp
|
||||
namespace {
|
||||
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
||||
AtenBroadcastToOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.self();
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto outType = getTypeConverter()
|
||||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO_STATIC_SHAPE
|
||||
if (selfTy.hasStaticShape()) {
|
||||
if (options.enableStaticShape && selfTy.hasStaticShape()) {
|
||||
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
|
||||
rewriter.replaceOp(op, bcastOp);
|
||||
return success();
|
||||
}
|
||||
#endif
|
||||
|
||||
SmallVector<Value> shape;
|
||||
if (!(getListConstructElements(adaptor.size(), shape))) {
|
||||
|
@ -428,14 +424,14 @@ public:
|
|||
bcastShapeVec.push_back(newD);
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
|
||||
if (options.dimSizeIndexBits == 32) {
|
||||
for (auto &dsize : bcastShapeVec) {
|
||||
auto dsizeI64 = rewriter.create<mlir::arith::IndexCastOp>(
|
||||
op->getLoc(), rewriter.getI64Type(), dsize);
|
||||
dsize = rewriter.create<arith::TruncIOp>(op->getLoc(),
|
||||
rewriter.getI32Type(), dsizeI64);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
Value bcastShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), ValueRange{bcastShapeVec});
|
||||
|
@ -445,18 +441,13 @@ public:
|
|||
op, outType, self, bcastShapeTensor,
|
||||
rewriter.getI64TensorAttr(dimensionNumbers));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
}
|
||||
|
||||
// AtenPermuteOp
|
||||
namespace {
|
||||
class ConvertAtenPermuteOp : public OpConversionPattern<AtenPermuteOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenPermuteOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
||||
AtenPermuteOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.self();
|
||||
// Not a ranked tensor type
|
||||
auto inType = self.getType().dyn_cast<RankedTensorType>();
|
||||
|
@ -486,25 +477,9 @@ public:
|
|||
rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
|
||||
permutation);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // namespace
|
||||
}
|
||||
|
||||
// AtenTanhOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
||||
AtenTanhOp op, OpAdaptor adaptor,
|
||||
|
@ -520,10 +495,8 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
|||
"only floating-point datatype legalization currently supported");
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// ValueTensorLiteralOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
||||
ValueTensorLiteralOp op, OpAdaptor adaptor,
|
||||
|
@ -553,11 +526,9 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// AtenReciprocalOp
|
||||
// Reciprocal(x) = Div(1, x)
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
||||
AtenReciprocalOp op, OpAdaptor adaptor,
|
||||
|
@ -575,10 +546,8 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
|||
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// PrimNumToTensorScalarOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
||||
PrimNumToTensorScalarOp op, OpAdaptor adaptor,
|
||||
|
@ -592,11 +561,9 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
|||
rewriter.replaceOp(op, mhloTensor);
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// AtenContiguousOp
|
||||
// Ref: TosaToTosa.cpp for implementation details
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
||||
AtenContiguousOp op, OpAdaptor adaptor,
|
||||
|
@ -614,11 +581,9 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// AtenReluOp
|
||||
// Relu(x) = Max(0, x)
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
||||
AtenReluOp op, OpAdaptor adaptor,
|
||||
|
@ -641,11 +606,9 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Convert a Aten::GELU to HLO
|
||||
// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
||||
AtenGeluOp op, OpAdaptor adaptor,
|
||||
|
@ -668,10 +631,8 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
|||
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// AtenErfOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
||||
AtenErfOp op, OpAdaptor adaptor,
|
||||
|
@ -686,10 +647,8 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// AtenBatchNormOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||
AtenBatchNormOp op, OpAdaptor adaptor,
|
||||
|
@ -716,12 +675,12 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
|
||||
Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1);
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
|
||||
if (options.dimSizeIndexBits == 32) {
|
||||
auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>(
|
||||
op->getLoc(), rewriter.getI64Type(), channelDim);
|
||||
channelDim = rewriter.create<arith::TruncIOp>(
|
||||
op->getLoc(), rewriter.getI32Type(), channelDimI64);
|
||||
#endif
|
||||
}
|
||||
|
||||
Value channelShape = rewriter.create<tensor::FromElementsOp>(
|
||||
op->getLoc(), ValueRange{channelDim});
|
||||
|
@ -806,10 +765,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// AtenNativeLayerNormOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||
AtenNativeLayerNormOp op, OpAdaptor adaptor,
|
||||
|
@ -949,10 +906,8 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// AtenCatOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||
AtenCatOp op, OpAdaptor adaptor,
|
||||
|
@ -983,10 +938,8 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
op, outType, ValueRange(builtinTensors), posDim);
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// AtenNumelOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
|
||||
AtenNumelOp op,
|
||||
|
@ -996,7 +949,7 @@ LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
|
|||
auto selfTy = self.getType().dyn_cast<RankedTensorType>();
|
||||
size_t rank = selfTy.getRank();
|
||||
|
||||
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
|
||||
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
|
||||
auto loc = op->getLoc();
|
||||
Value numel =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(intType, 1));
|
||||
|
@ -1015,26 +968,18 @@ LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
|
|||
}
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
target.addIllegalOp<AtenTransposeIntOp>();
|
||||
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenBroadcastToOp>();
|
||||
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
|
||||
|
||||
target.addIllegalOp<AtenPermuteOp>();
|
||||
patterns.add<ConvertAtenPermuteOp>(typeConverter, context);
|
||||
|
||||
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, \
|
||||
context);
|
||||
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, context)
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
|
||||
INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp);
|
||||
|
@ -1045,14 +990,14 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
|
||||
context);
|
||||
context)
|
||||
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
|
||||
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
||||
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||
|
||||
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
|
||||
|
@ -1062,7 +1007,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
|
||||
#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context)
|
||||
INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
|
||||
INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
|
||||
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
|
||||
|
@ -1072,7 +1017,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
|
||||
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context)
|
||||
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
|
||||
|
@ -1086,7 +1031,11 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
||||
|
||||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||
INSERT_ATENOP_PATTERN(AtenPermuteOp);
|
||||
|
||||
INSERT_ATENOP_PATTERN(AtenTanhOp);
|
||||
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
|
||||
|
|
|
@ -23,18 +23,14 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
|
||||
static constexpr size_t kMhloDimSizeBits = 32;
|
||||
#else
|
||||
static constexpr size_t kMhloDimSizeBits = 64;
|
||||
#endif
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
|
||||
namespace {
|
||||
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
||||
Value input, Value indices, int64_t axis) {
|
||||
Value input, Value indices, int64_t axis,
|
||||
size_t dimSizeIndexBits) {
|
||||
auto loc = op->getLoc();
|
||||
Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
|
||||
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
||||
Value one = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(intType, 1));
|
||||
|
||||
|
@ -98,16 +94,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
|||
sliceSizesTensor, dimsAttr)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
|
||||
// padding_idx (int, optional)
|
||||
|
@ -149,8 +136,8 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "sparse gradients is currently not supported");
|
||||
|
||||
Value output =
|
||||
gatherTensorAlongSingleAxis(rewriter, op, weight, adaptor.indices(), 0);
|
||||
Value output = gatherTensorAlongSingleAxis(
|
||||
rewriter, op, weight, adaptor.indices(), 0, options.dimSizeIndexBits);
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), output);
|
||||
|
||||
|
@ -170,24 +157,23 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant dim is currently supported");
|
||||
|
||||
Value output =
|
||||
gatherTensorAlongSingleAxis(rewriter, op, self, adaptor.index(), dim);
|
||||
Value output = gatherTensorAlongSingleAxis(
|
||||
rewriter, op, self, adaptor.index(), dim, options.dimSizeIndexBits);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), output);
|
||||
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
||||
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
|
||||
namespace {
|
||||
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
||||
|
@ -71,7 +72,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
|
|||
}
|
||||
|
||||
void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
||||
Value &inpRhs, int64_t leadingRank) {
|
||||
Value &inpRhs, int64_t leadingRank,
|
||||
size_t dimSizeIndexBits) {
|
||||
Value lhs = inpLhs;
|
||||
Value rhs = inpRhs;
|
||||
auto lhsRankTy = inpLhs.getType().dyn_cast<RankedTensorType>();
|
||||
|
@ -92,9 +94,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
|||
std::vector<int64_t> newShape(rhsShape.begin(),
|
||||
rhsShape.begin() + leadingRank);
|
||||
newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
|
||||
auto newDimSizes =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims);
|
||||
auto lhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, lhs);
|
||||
auto newDimSizes = *mhlo::getDimSizesOfTensor(
|
||||
rewriter, op, rhs, leadingDims, dimSizeIndexBits);
|
||||
auto lhsDimSizes =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
|
||||
newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
|
||||
lhsDimSizes.end());
|
||||
lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,
|
||||
|
@ -103,9 +106,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
|||
std::vector<int64_t> newShape(lhsShape.begin(),
|
||||
lhsShape.begin() + leadingRank);
|
||||
newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
|
||||
auto newDimSizes =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims);
|
||||
auto rhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, rhs);
|
||||
auto newDimSizes = *mhlo::getDimSizesOfTensor(
|
||||
rewriter, op, lhs, leadingDims, dimSizeIndexBits);
|
||||
auto rhsDimSizes =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
|
||||
newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
|
||||
rhsDimSizes.end());
|
||||
rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,
|
||||
|
@ -122,9 +126,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
|||
// implement their specialized input processing (e.g transpose), and output
|
||||
// processing, e.g. GEMM or fully connected bias handling.
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
|
||||
class ConvertAtenMatmulBaseOp : public ConvertAtenOp<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
// Each variant must implement corresponding parameter parsing options.
|
||||
// Maintain separate input read functions for each variant because it is not
|
||||
|
@ -159,20 +163,24 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
|
||||
int64_t nBatchDims;
|
||||
if (rhsRank <= 2) {
|
||||
auto leadingRank = lhsRank - 2;
|
||||
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
|
||||
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
|
||||
options.dimSizeIndexBits);
|
||||
nBatchDims = leadingRank;
|
||||
} else if (lhsRank <= 2) {
|
||||
auto leadingRank = rhsRank - 2;
|
||||
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
|
||||
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
|
||||
options.dimSizeIndexBits);
|
||||
nBatchDims = leadingRank;
|
||||
} else {
|
||||
assert(rhsRank > 2 && lhsRank > 2);
|
||||
auto leadingRank = std::max(lhsRank - rhsRank, rhsRank - lhsRank);
|
||||
nBatchDims = std::max(lhsRank - 2, rhsRank - 2);
|
||||
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
|
||||
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
|
||||
options.dimSizeIndexBits);
|
||||
}
|
||||
auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
|
||||
auto lhsContractingDim = nBatchDims + 1;
|
||||
|
@ -187,7 +195,7 @@ public:
|
|||
/*rhsBatchingDimensions=*/batchDims,
|
||||
/*lhsContractingDimensions=*/{lhsContractingDim},
|
||||
/*rhsContractingDimensions=*/{rhsContractingDim});
|
||||
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
auto resultTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
|
||||
|
@ -215,7 +223,7 @@ public:
|
|||
|
||||
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
ConvertAtenOp<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<RankedTensorType>(),
|
||||
output);
|
||||
|
@ -340,7 +348,10 @@ public:
|
|||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
||||
auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
|
||||
rhsTy.getRank() - lhsTy.getRank());
|
||||
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
|
||||
|
||||
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
|
||||
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
|
||||
options.dimSizeIndexBits);
|
||||
auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
|
||||
auto nBatchDims = resultRank - 2;
|
||||
auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
|
||||
|
@ -356,8 +367,7 @@ public:
|
|||
/*rhsContractingDimensions=*/{rhsContractingDim});
|
||||
|
||||
auto resultTy =
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType());
|
||||
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
|
||||
|
||||
Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
|
||||
op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr);
|
||||
|
@ -377,12 +387,9 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
|
||||
class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
|
||||
public:
|
||||
using OpConversionPattern<AtenConvolutionOp>::OpConversionPattern;
|
||||
using ConvertAtenOp<AtenConvolutionOp>::ConvertAtenOp;
|
||||
using OpAdaptor = typename AtenConvolutionOp::Adaptor;
|
||||
|
||||
Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op,
|
||||
|
@ -390,8 +397,9 @@ public:
|
|||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto weightElemTy = weightTy.getElementType();
|
||||
auto rank = weightTy.getRank();
|
||||
SmallVector<Value> weightShapeVec =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, weight);
|
||||
const auto &options = getOptions();
|
||||
SmallVector<Value> weightShapeVec = *mhlo::getDimSizesOfTensor(
|
||||
rewriter, op, weight, options.dimSizeIndexBits);
|
||||
auto weightShape = weightTy.getShape();
|
||||
SmallVector<int64_t> weightShapeInt(rank);
|
||||
std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin());
|
||||
|
@ -601,9 +609,8 @@ public:
|
|||
return mhloConvOp.getResult();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.input();
|
||||
Value weight = adaptor.weight();
|
||||
|
||||
|
@ -714,7 +721,10 @@ public:
|
|||
// Reshape and promote bias
|
||||
auto inputUnsqzDims =
|
||||
llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
|
||||
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
|
||||
|
||||
const auto &options = getOptions();
|
||||
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
|
||||
options.dimSizeIndexBits);
|
||||
bias = mhlo::promoteType(rewriter, bias, outTy);
|
||||
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
|
@ -727,31 +737,31 @@ public:
|
|||
|
||||
void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context, options)
|
||||
INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
|
||||
#undef INSERT_MATMUL_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_MM_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context, options)
|
||||
INSERT_MM_ATENOP_PATTERN(AtenMmOp);
|
||||
INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
|
||||
#undef INSERT_MM_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context, options)
|
||||
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
|
||||
#undef INSERT_LINEAR_ATEMOP_PATTERN
|
||||
|
||||
#define INSERT_CONVOLUTION_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
|
||||
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context, options)
|
||||
INSERT_CONVOLUTION_ATENOP_PATTERN(AtenConvolutionOp);
|
||||
#undef INSERT_CONVOLUTION_ATENOP_PATTERN
|
||||
}
|
||||
|
|
|
@ -259,9 +259,10 @@ SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
|
|||
return posDims;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Value, 4>>
|
||||
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
|
||||
ArrayRef<int64_t> inpDims) {
|
||||
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
|
||||
Operation *op, Value value,
|
||||
ArrayRef<int64_t> inpDims,
|
||||
size_t dimSizeIndexBits) {
|
||||
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
|
||||
if (!valueTy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -276,14 +277,15 @@ getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
|
|||
auto loc = op->getLoc();
|
||||
for (auto d : dims) {
|
||||
dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
|
||||
loc, rewriter.getIntegerType(kMhloDimSizeBits),
|
||||
loc, rewriter.getIntegerType(dimSizeIndexBits),
|
||||
rewriter.create<tensor::DimOp>(loc, value, d)));
|
||||
}
|
||||
return dimSizes;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Value, 4>>
|
||||
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
|
||||
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
|
||||
Operation *op, Value value,
|
||||
size_t dimSizeIndexBits) {
|
||||
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
|
||||
if (!valueTy) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -294,12 +296,12 @@ getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
|
|||
// Get int vector [0, 1, ..., rank-1]
|
||||
std::vector<int64_t> dims(rank);
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
return getDimSizesOfTensor(rewriter, op, value, dims);
|
||||
return getDimSizesOfTensor(rewriter, op, value, dims, dimSizeIndexBits);
|
||||
}
|
||||
|
||||
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
||||
Value tensor,
|
||||
ArrayRef<int64_t> inputUnsqzDims) {
|
||||
Value tensor, ArrayRef<int64_t> inputUnsqzDims,
|
||||
size_t dimSizeIndexBits) {
|
||||
// Returns a new tensor with dims of size 1 inserted at the specified
|
||||
// position.
|
||||
//
|
||||
|
@ -307,7 +309,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
|||
// tensor) are specified with unsqzDims. Indices must be in-order, and in
|
||||
// range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1,
|
||||
// 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not.
|
||||
auto dimSizesInfo = getDimSizesOfTensor(rewriter, op, tensor);
|
||||
auto dimSizesInfo =
|
||||
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
|
||||
if (failed(dimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
@ -324,7 +327,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
|||
auto loc = op->getLoc();
|
||||
auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
|
||||
auto oldShape = rankTy.getShape();
|
||||
Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
|
||||
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
||||
auto one = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(intType, 1));
|
||||
|
||||
|
|
|
@ -19,11 +19,6 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
|
||||
static constexpr size_t kMhloDimSizeBits = 32;
|
||||
#else
|
||||
static constexpr size_t kMhloDimSizeBits = 64;
|
||||
#endif
|
||||
|
||||
using mlir::ConversionPatternRewriter;
|
||||
|
||||
|
@ -60,17 +55,18 @@ SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank);
|
|||
// Get the dimension sizes of the input tensor, given the dimension axes
|
||||
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
|
||||
Operation *op, Value value,
|
||||
ArrayRef<int64_t> inpDims);
|
||||
ArrayRef<int64_t> inpDims,
|
||||
size_t dimSizeIndexBits);
|
||||
|
||||
// Get the dimension sizes of the input tensor
|
||||
FailureOr<SmallVector<Value, 4>>
|
||||
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value);
|
||||
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
|
||||
Operation *op, Value value,
|
||||
size_t dimSizeIndexBits);
|
||||
|
||||
// Get a tensor that unsqueezed the specified dimensions of the input tensor
|
||||
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
||||
Value tensor,
|
||||
ArrayRef<int64_t> inputUnsqzDims);
|
||||
|
||||
Value tensor, ArrayRef<int64_t> inputUnsqzDims,
|
||||
size_t dimSizeIndexBits);
|
||||
|
||||
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
||||
const APFloat &constant, Value shape,
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
|
||||
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||
PatternRewriter &rewriter) {
|
||||
|
@ -72,22 +73,9 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenPoolingOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// AtenMaxPool2dOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dOp>::matchAndRewrite(
|
||||
LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
||||
AtenMaxPool2dOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.self();
|
||||
|
@ -186,12 +174,10 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dOp>::matchAndRewrite(
|
|||
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// AtenMaxPool2dWithIndicesOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||
AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.self();
|
||||
|
@ -269,7 +255,9 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
rewriter.getI64Type()),
|
||||
mhloPadding);
|
||||
|
||||
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(inputShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
@ -379,12 +367,10 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
|||
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// AtenAvgPool2dOp
|
||||
namespace {
|
||||
template <>
|
||||
LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
|
||||
LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
||||
AtenAvgPool2dOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.self();
|
||||
|
@ -502,7 +488,9 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
Value windowSizeConst =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
|
||||
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
|
||||
auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input);
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeVec =
|
||||
*mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), inputShapeVec);
|
||||
|
||||
|
@ -540,17 +528,15 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||
patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
|
||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
||||
patterns.add<ConvertAtenPoolingOp<AtenAvgPool2dOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenOp<AtenAvgPool2dOp>>(typeConverter, context, options);
|
||||
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
|
||||
patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
|
||||
context);
|
||||
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
|
||||
context, options);
|
||||
}
|
||||
|
|
|
@ -16,25 +16,56 @@ namespace mlir {
|
|||
namespace torch {
|
||||
namespace torch_to_mhlo {
|
||||
|
||||
struct TorchToMhloOptions {
|
||||
bool enableStaticShape = false;
|
||||
size_t dimSizeIndexBits = 64;
|
||||
};
|
||||
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
|
||||
const TorchToMhloOptions &options)
|
||||
: OpConversionPattern<AtenOpT>(typeConverter, context) {
|
||||
this->options = options;
|
||||
}
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
return rewriter.notifyMatchFailure(op, "haven't been implemented");
|
||||
}
|
||||
const TorchToMhloOptions &getOptions() const { return options; }
|
||||
|
||||
private:
|
||||
TorchToMhloOptions options;
|
||||
};
|
||||
|
||||
void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target);
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target);
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target);
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target);
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target);
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
|
||||
void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target);
|
||||
ConversionTarget &target,
|
||||
const TorchToMhloOptions &options);
|
||||
|
||||
} // namespace torch_to_mhlo
|
||||
} // namespace torch
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
|
||||
static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
||||
PatternRewriter &rewriter) {
|
||||
|
@ -72,7 +73,8 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
|
|||
// Util for converting AtenArgmaxOp and AtenMaxDimOp
|
||||
static llvm::Optional<ValueRange>
|
||||
getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
||||
ArrayRef<Value> inputShapeVec, int64_t dim) {
|
||||
ArrayRef<Value> inputShapeVec, int64_t dim,
|
||||
size_t dimSizeIndexBits) {
|
||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
||||
if (!inputTy) {
|
||||
return llvm::None;
|
||||
|
@ -86,7 +88,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter);
|
||||
if (!initValue) return llvm::None;
|
||||
Value initIndex;
|
||||
if (mlir::mhlo::kMhloDimSizeBits == 32) {
|
||||
if (dimSizeIndexBits == 32) {
|
||||
initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
|
||||
} else {
|
||||
initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
||||
|
@ -98,7 +100,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), inputShapeVec);
|
||||
auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
|
||||
op->getLoc(), RankedTensorType::get(inputShape, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits)),
|
||||
op->getLoc(),
|
||||
RankedTensorType::get(inputShape,
|
||||
rewriter.getIntegerType(dimSizeIndexBits)),
|
||||
inputShapeTensor, static_cast<uint64_t>(dim));
|
||||
|
||||
auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
|
||||
|
@ -114,7 +118,8 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
// Add block arguments
|
||||
auto blockValArgumentType =
|
||||
RankedTensorType::get({}, inputTy.getElementType());
|
||||
auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits));
|
||||
auto blockIdxArgumentType =
|
||||
RankedTensorType::get({}, rewriter.getIntegerType(dimSizeIndexBits));
|
||||
auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
|
||||
block.addArgument(blockValArgumentType, op->getLoc());
|
||||
block.addArgument(blockIdxArgumentType, op->getLoc());
|
||||
|
@ -171,9 +176,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
|||
|
||||
namespace {
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
|
||||
class ConvertAtenReductionOp : public ConvertAtenOp<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
|
@ -220,21 +225,24 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
|
||||
}
|
||||
|
||||
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(inputShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
}
|
||||
auto inputShapeVec = *inputShapeInfo;
|
||||
auto mhloReduceResults =
|
||||
getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
|
||||
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
|
||||
options.dimSizeIndexBits)
|
||||
.value();
|
||||
|
||||
if (keepDim) {
|
||||
auto outShapeVec = inputShapeVec;
|
||||
|
||||
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits), 1));
|
||||
op->getLoc(),
|
||||
rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
|
||||
|
||||
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), outShapeVec);
|
||||
|
@ -297,20 +305,24 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
|
||||
}
|
||||
|
||||
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
|
||||
const auto &options = getOptions();
|
||||
auto inputShapeInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||
if (failed(inputShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
}
|
||||
auto inputShapeVec = *inputShapeInfo;
|
||||
auto mhloReduceResults =
|
||||
getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
|
||||
auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
|
||||
options.dimSizeIndexBits)
|
||||
.value();
|
||||
|
||||
if (keepDim) {
|
||||
auto outShapeVec = inputShapeVec;
|
||||
outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1));
|
||||
op->getLoc(),
|
||||
rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
|
||||
auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), outShapeVec);
|
||||
|
||||
|
@ -532,15 +544,18 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
}
|
||||
|
||||
if (keepDim) {
|
||||
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
|
||||
const auto &options = getOptions();
|
||||
auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input,
|
||||
options.dimSizeIndexBits);
|
||||
if (failed(outShapeInfo)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
}
|
||||
auto outShapeVec = *outShapeInfo;
|
||||
auto one = rewriter.create<mlir::arith::ConstantOp>(
|
||||
op->getLoc(), rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1));
|
||||
op->getLoc(),
|
||||
rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(options.dimSizeIndexBits), 1));
|
||||
for (int64_t i : dims) {
|
||||
outShapeVec[i] = one;
|
||||
}
|
||||
|
@ -558,11 +573,11 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|||
|
||||
void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
|
||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
|
||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
|
||||
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
|
||||
|
|
|
@ -32,6 +32,12 @@ namespace {
|
|||
|
||||
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
|
||||
public:
|
||||
ConvertTorchToMhlo() = default;
|
||||
ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) {
|
||||
this->enableStaticShape = enableStaticShape;
|
||||
this->enableI32Index = enableI32Index;
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<chlo::ChloDialect>();
|
||||
registry.insert<mhlo::MhloDialect>();
|
||||
|
@ -51,18 +57,20 @@ public:
|
|||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
torch_to_mhlo::TorchToMhloOptions options{enableStaticShape,
|
||||
enableI32Index ? 32u : 64u};
|
||||
torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
|
||||
target);
|
||||
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter,
|
||||
patterns, target);
|
||||
target, options);
|
||||
torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
|
||||
target);
|
||||
torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter,
|
||||
patterns, target);
|
||||
target, options);
|
||||
torch_to_mhlo::populateReductionOpPatternsAndLegality(
|
||||
typeConverter, patterns, target, options);
|
||||
torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
|
||||
target);
|
||||
target, options);
|
||||
torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
|
||||
target);
|
||||
target, options);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
@ -75,5 +83,12 @@ public:
|
|||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToMhloPass() {
|
||||
return std::make_unique<ConvertTorchToMhlo>();
|
||||
return std::make_unique<ConvertTorchToMhlo>(false, false);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape,
|
||||
bool enableI32Index) {
|
||||
return std::make_unique<ConvertTorchToMhlo>(enableStaticShape,
|
||||
enableI32Index);
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
using namespace mlir::torch::TorchConversion;
|
||||
using namespace mlir::torch::torch_to_mhlo;
|
||||
|
||||
namespace {
|
||||
// A dimension index from torch.dialect might outside the range [0, dimSize].
|
||||
|
@ -55,10 +56,11 @@ Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
|
|||
Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
|
||||
Type outTy, Value input, Value startIndex,
|
||||
Value endIndex, Value step, size_t dimIndex,
|
||||
ArrayRef<Value> dimSizes) {
|
||||
ArrayRef<Value> dimSizes,
|
||||
size_t dimSizeIndexBits) {
|
||||
auto loc = op->getLoc();
|
||||
// startIndex & endIndex has been normailized into range [0, dSize]
|
||||
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
|
||||
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(intType, 0));
|
||||
Value one = rewriter.create<arith::ConstantOp>(
|
||||
|
@ -109,7 +111,8 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
|
|||
Type outTy, Value input,
|
||||
llvm::Optional<Value> startIndexOpt,
|
||||
llvm::Optional<Value> endIndexOpt,
|
||||
llvm::Optional<Value> stepOpt, int64_t dim) {
|
||||
llvm::Optional<Value> stepOpt, int64_t dim,
|
||||
size_t dimSizeIndexBits) {
|
||||
auto loc = op->getLoc();
|
||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
||||
auto rank = inputTy.getRank();
|
||||
|
@ -133,77 +136,31 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
|
|||
: rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
|
||||
auto i32Type = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
|
||||
if (dimSizeIndexBits == 32) {
|
||||
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
||||
normStartIndex =
|
||||
rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex);
|
||||
normEndIndex = rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
|
||||
step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
|
||||
#endif
|
||||
rewriter.create<arith::TruncIOp>(loc, intType, normStartIndex);
|
||||
normEndIndex = rewriter.create<arith::TruncIOp>(loc, intType, normEndIndex);
|
||||
step = rewriter.create<arith::TruncIOp>(loc, intType, step);
|
||||
}
|
||||
FailureOr<SmallVector<Value, 4>> dimSizesInfo =
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input);
|
||||
mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
|
||||
if (failed(dimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
||||
auto dimSizes = *dimSizesInfo;
|
||||
return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex,
|
||||
normEndIndex, step, dim, dimSizes);
|
||||
}
|
||||
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
||||
AtenSliceTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.self();
|
||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
if (!selfTy)
|
||||
return op.emitError("only ranked tensor types are supported");
|
||||
auto outTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant dim is currently supported");
|
||||
|
||||
auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
|
||||
if (val.getType().isa<Torch::NoneType>()) {
|
||||
return llvm::None;
|
||||
} else {
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
llvm::Optional<Value> start = getOptionalVal(adaptor.start());
|
||||
llvm::Optional<Value> end = getOptionalVal(adaptor.end());
|
||||
llvm::Optional<Value> step = getOptionalVal(adaptor.step());
|
||||
|
||||
FailureOr<Value> sliceInfo =
|
||||
getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim);
|
||||
if (failed(sliceInfo))
|
||||
return op.emitError("can not create a dynmaic slice");
|
||||
|
||||
auto slice = *sliceInfo;
|
||||
rewriter.replaceOp(op, slice);
|
||||
return success();
|
||||
normEndIndex, step, dim, dimSizes,
|
||||
dimSizeIndexBits);
|
||||
}
|
||||
|
||||
// This defines a template to construct ops whose legalizations are
|
||||
// specialized.
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
|
||||
class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
|
||||
LogicalResult
|
||||
|
@ -235,19 +192,19 @@ public:
|
|||
return dSize;
|
||||
});
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
|
||||
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
|
||||
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
|
||||
if (options.dimSizeIndexBits == 32) {
|
||||
// The i64 calculation is much slower than i32 on some devices, such as
|
||||
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
|
||||
// unlikely to exceed the range of i32(4GiB)
|
||||
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
|
||||
// dimSize: cast i64 -> i32
|
||||
dSize =
|
||||
rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
|
||||
dSize = rewriter.create<arith::TruncIOp>(loc, intType, dSize);
|
||||
return dSize;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
|
||||
Value numel = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(intType, 1));
|
||||
for (auto d : dimSizes) {
|
||||
|
@ -293,6 +250,45 @@ bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
|
|||
SmallVector<Value, 4> &dimSizes) const {
|
||||
return getListConstructElements(adaptor.shape(), dimSizes);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
||||
AtenSliceTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto self = adaptor.self();
|
||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
if (!selfTy)
|
||||
return op.emitError("only ranked tensor types are supported");
|
||||
auto outTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant dim is currently supported");
|
||||
|
||||
auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
|
||||
if (val.getType().isa<Torch::NoneType>()) {
|
||||
return llvm::None;
|
||||
} else {
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
llvm::Optional<Value> start = getOptionalVal(adaptor.start());
|
||||
llvm::Optional<Value> end = getOptionalVal(adaptor.end());
|
||||
llvm::Optional<Value> step = getOptionalVal(adaptor.step());
|
||||
|
||||
FailureOr<Value> sliceInfo =
|
||||
getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim,
|
||||
options.dimSizeIndexBits);
|
||||
if (failed(sliceInfo))
|
||||
return op.emitError("can not create a dynmaic slice");
|
||||
|
||||
auto slice = *sliceInfo;
|
||||
rewriter.replaceOp(op, slice);
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
||||
|
@ -324,7 +320,8 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
|
||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||
options.dimSizeIndexBits);
|
||||
if (failed(newDimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
@ -372,7 +369,8 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
|
|||
op, getTypeConverter()->convertType(op.getType()), self);
|
||||
return success();
|
||||
}
|
||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
|
||||
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
||||
options.dimSizeIndexBits);
|
||||
if (failed(newDimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
@ -397,8 +395,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
|||
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
||||
return op->emitError("dim must be a Scalar constant");
|
||||
|
||||
auto unsqzTensorInfo =
|
||||
mhlo::unsqueezeTensor(rewriter, op, adaptor.self(), {dim});
|
||||
auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.self(),
|
||||
{dim}, options.dimSizeIndexBits);
|
||||
if (failed(unsqzTensorInfo))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"failed to create unsqueezed tensor");
|
||||
|
@ -406,16 +404,15 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
|||
rewriter.replaceOp(op, *unsqzTensorInfo);
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
||||
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSqueezeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
|
||||
|
@ -424,7 +421,7 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
|
|||
|
||||
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
|
||||
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
|
||||
INSERT_VIEW_OP_PATTERN(AtenViewOp);
|
||||
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
|
||||
#undef INSERT_VIEW_OP_PATTERN
|
||||
|
|
Loading…
Reference in New Issue