mirror of https://github.com/llvm/torch-mlir
[MHLO] refactor pass configurations (#1315)
Related to https://github.com/llvm/torch-mlir/issues/1227:
1. Reduce MHLO #ifdefs.
2. Dismiss compilation warnings.

branch pull/1321/head
parent 7769eb88f8
commit 29cafdbb61
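At its core the patch trades build-time #ifdef branches for a runtime option struct that every conversion pattern can read. A standalone sketch of the idea, reusing the TorchToMhloOptions fields introduced later in this diff (the lowering logic is stubbed out with prints, so this is illustrative only):

    #include <cstddef>
    #include <cstdio>

    // Mirrors the struct added by this patch; the 32-bit case replaces the old
    // TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 macro.
    struct TorchToMhloOptions {
      bool enableStaticShape = false;
      size_t dimSizeIndexBits = 64;
    };

    void lowerDimSizes(const TorchToMhloOptions &options) {
      if (options.dimSizeIndexBits == 32) {
        std::puts("emit i32 dimension sizes");  // formerly the #ifdef branch
      } else {
        std::puts("emit i64 dimension sizes");  // formerly the #else branch
      }
    }

    int main() {
      lowerDimSizes({/*enableStaticShape=*/false, /*dimSizeIndexBits=*/32});
    }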
@@ -39,14 +39,6 @@ endmacro()
 option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
 if(TORCH_MLIR_ENABLE_MHLO)
   add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
-  # The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
-  # One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
-  # the range of i32(4GiB)
-  option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
-         "Enable truncate dimension size from i64 to i32(unsafely)" OFF)
-  if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
-    add_definitions(-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
-  endif()
 endif()

 option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
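With the CMake knob gone, the i64-to-i32 truncation is requested per pass instance instead of per build. A minimal sketch of constructing the pass with the overload added below in TorchToMhlo.h; the include paths are assumptions, not part of the patch:

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"  // path assumed

    void addTorchToMhloPasses(mlir::PassManager &pm) {
      // enableI32Index plays the role of the removed CMake option.
      pm.addNestedPass<mlir::func::FuncOp>(
          mlir::torch::createConvertTorchToMhloPass(
              /*enableStaticShape=*/false, /*enableI32Index=*/true));
    }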
@@ -132,6 +132,17 @@ def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
     Convert Torch ops to mhlo ops.
   }];
   let constructor = "mlir::torch::createConvertTorchToMhloPass()";

+  // Specify any options.
+  let options = [
+    Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false",
+           "Enable static shape conversion">,
+    // The i64 calculation is much slower than i32 on some devices, such as
+    // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
+    // are unlikely to exceed the range of i32(4GiB)
+    Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
+           "Enable truncate index from i64 to i32(unsafely)">,
+  ];
 }

 #endif
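The two Option entries above surface as textual pass options, so the same choices can be made from a pipeline string. A sketch, assuming the pass is scheduled under func.func as its Pass declaration suggests:

    #include "mlir/Pass/PassManager.h"
    #include "mlir/Pass/PassRegistry.h"

    mlir::LogicalResult buildPipeline(mlir::PassManager &pm) {
      // Roughly equivalent to createConvertTorchToMhloPass(true, true).
      return mlir::parsePassPipeline(
          "func.func(convert-torch-to-mhlo{enable-static-shape=true "
          "enable-i32-index=true})",
          pm);
    }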
@@ -17,6 +17,8 @@
 namespace mlir {
 namespace torch {
 std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
 } // namespace torch
 } // namespace mlir
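The new overload takes two booleans, while the patterns later in the patch consume a TorchToMhloOptions struct whose index width is a bit count. A hypothetical helper (not in the patch) spelling out the mapping the rest of the diff implies:

    // enableI32Index presumably selects 32-bit dimension-size arithmetic;
    // otherwise the 64-bit default of TorchToMhloOptions is kept.
    TorchToMhloOptions makeOptions(bool enableStaticShape, bool enableI32Index) {
      TorchToMhloOptions options;
      options.enableStaticShape = enableStaticShape;
      options.dimSizeIndexBits = enableI32Index ? 32 : 64;
      return options;
    }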
@@ -29,6 +29,7 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
+using namespace mlir::torch::torch_to_mhlo;

 bool skipMultiplyAlpha(Value alphaValue) {
   double doubleValue;
@@ -379,63 +380,58 @@ public:
 } // namespace

 // AtenBroadcastToOp
-namespace {
-class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value self = adaptor.self();
-    auto selfTy = self.getType().cast<RankedTensorType>();
-    auto outType = getTypeConverter()
-                       ->convertType(op->getResult(0).getType())
-                       .cast<RankedTensorType>();
+template <>
+LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
+    AtenBroadcastToOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value self = adaptor.self();
+  auto selfTy = self.getType().cast<RankedTensorType>();
+  auto outType = getTypeConverter()
+                     ->convertType(op->getResult(0).getType())
+                     .cast<RankedTensorType>();

-#ifdef TORCH_MLIR_ENABLE_MHLO_STATIC_SHAPE
-  if (selfTy.hasStaticShape()) {
-    Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
-    rewriter.replaceOp(op, bcastOp);
-    return success();
+  if (options.enableStaticShape && selfTy.hasStaticShape()) {
+    Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
+    rewriter.replaceOp(op, bcastOp);
+    return success();
+  }

+  SmallVector<Value> shape;
+  if (!(getListConstructElements(adaptor.size(), shape))) {
+    return op->emitError("desired shape must be a list of scalar");
+  }
+  SmallVector<Value> bcastShapeVec;
+  int64_t totalRank = shape.size();
+  int64_t selfRank = selfTy.getRank();
+  int64_t leadingRank = totalRank - selfRank;
+
+  for (int64_t i = 0; i < totalRank; ++i) {
+    Value dValue = shape[i];
+    Value newD;
+    int64_t dInt;
+    if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) {
+      return op->emitError("element of desired shape must be a scalar");
     }
-#endif
-  SmallVector<Value> shape;
-  if (!(getListConstructElements(adaptor.size(), shape))) {
-    return op->emitError("desired shape must be a list of scalar");
+    if (i >= leadingRank && dInt == -1) {
+      newD = rewriter.create<mlir::tensor::DimOp>(op->getLoc(), self,
+                                                  i - leadingRank);
+    } else {
+      dValue = rewriter.create<torch::TorchConversion::ToI64Op>(op->getLoc(),
+                                                                dValue);
+      newD = rewriter.create<mlir::arith::IndexCastOp>(
+          op->getLoc(), rewriter.getIndexType(), dValue);
     }
-  SmallVector<Value> bcastShapeVec;
-  int64_t totalRank = shape.size();
-  int64_t selfRank = selfTy.getRank();
-  int64_t leadingRank = totalRank - selfRank;
+    bcastShapeVec.push_back(newD);
+  }

-  for (int64_t i = 0; i < totalRank; ++i) {
-    Value dValue = shape[i];
-    Value newD;
-    int64_t dInt;
-    if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) {
-      return op->emitError("element of desired shape must be a scalar");
-    }
-    if (i >= leadingRank && dInt == -1) {
-      newD = rewriter.create<mlir::tensor::DimOp>(op->getLoc(), self,
-                                                  i - leadingRank);
-    } else {
-      dValue = rewriter.create<torch::TorchConversion::ToI64Op>(op->getLoc(),
-                                                                dValue);
-      newD = rewriter.create<mlir::arith::IndexCastOp>(
-          op->getLoc(), rewriter.getIndexType(), dValue);
-    }
-    bcastShapeVec.push_back(newD);
-  }
+  if (options.dimSizeIndexBits == 32) {

-#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
     for (auto &dsize : bcastShapeVec) {
       auto dsizeI64 = rewriter.create<mlir::arith::IndexCastOp>(
           op->getLoc(), rewriter.getI64Type(), dsize);
       dsize = rewriter.create<arith::TruncIOp>(op->getLoc(),
                                                rewriter.getI32Type(), dsizeI64);
     }
-#endif
+  }

   Value bcastShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
       op->getLoc(), ValueRange{bcastShapeVec});
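The truncation sequence guarded by options.dimSizeIndexBits == 32 above (index value -> i64 -> trunc to i32) recurs in several patterns of this patch. Written as a free helper it would look like the sketch below; this helper is illustrative only and is not added by the patch (arith dialect headers assumed):

    static mlir::Value truncDimSizeToI32(mlir::PatternRewriter &rewriter,
                                         mlir::Location loc, mlir::Value dimSize) {
      // Widen the index value to i64 first, then truncate to i32.
      auto asI64 = rewriter.create<mlir::arith::IndexCastOp>(
          loc, rewriter.getI64Type(), dimSize);
      return rewriter.create<mlir::arith::TruncIOp>(loc, rewriter.getI32Type(),
                                                    asI64);
    }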
@@ -445,66 +441,45 @@ public:
       op, outType, self, bcastShapeTensor,
       rewriter.getI64TensorAttr(dimensionNumbers));
   return success();
 }
-};
-} // namespace

 // AtenPermuteOp
-namespace {
-class ConvertAtenPermuteOp : public OpConversionPattern<AtenPermuteOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenPermuteOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value self = adaptor.self();
-    // Not a ranked tensor type
-    auto inType = self.getType().dyn_cast<RankedTensorType>();
-    auto outType = getTypeConverter()
-                       ->convertType(op->getResult(0).getType())
-                       .cast<RankedTensorType>();
-    if (!inType)
-      return op.emitError("only ranked tensor types with static shapes are "
-                          "currently supported");
+template <>
+LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
+    AtenPermuteOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value self = adaptor.self();
+  // Not a ranked tensor type
+  auto inType = self.getType().dyn_cast<RankedTensorType>();
+  auto outType = getTypeConverter()
+                     ->convertType(op->getResult(0).getType())
+                     .cast<RankedTensorType>();
+  if (!inType)
+    return op.emitError("only ranked tensor types with static shapes are "
+                        "currently supported");

   SmallVector<int64_t> permValues;
   if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues)))
     return rewriter.notifyMatchFailure(
         op, "only constant dimensions are currently supported");

   int64_t inRank = inType.getRank();
   for (auto &d : permValues) {
     d = toPositiveDim(d, inRank);
     if (!isValidDim(d, inRank))
       return op.emitError("not all dims are valid");
-    }

-    DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
-        RankedTensorType::get({static_cast<long int>(permValues.size())},
-                              rewriter.getI64Type()),
-        permValues);
-    rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
-                                                   permutation);
-    return success();
   }
-};

-} // namespace
-namespace {
-template <typename AtenOpT>
-class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
-public:
-  using OpConversionPattern<AtenOpT>::OpConversionPattern;
-  using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult
-  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
+  DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
+      RankedTensorType::get({static_cast<long int>(permValues.size())},
+                            rewriter.getI64Type()),
+      permValues);
+  rewriter.replaceOpWithNewOp<mhlo::TransposeOp>(op, outType, self,
+                                                 permutation);
+  return success();
+}

 // AtenTanhOp
-namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
     AtenTanhOp op, OpAdaptor adaptor,
@@ -520,10 +495,8 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
         "only floating-point datatype legalization currently supported");
   }
 }
-} // namespace

 // ValueTensorLiteralOp
-namespace {
 template <>
 LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
     ValueTensorLiteralOp op, OpAdaptor adaptor,

@@ -553,11 +526,9 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
   return success();
 }

-} // namespace

 // AtenReciprocalOp
 // Reciprocal(x) = Div(1, x)
-namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
     AtenReciprocalOp op, OpAdaptor adaptor,

@@ -575,10 +546,8 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
   rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
   return success();
 }
-} // namespace

 // PrimNumToTensorScalarOp
-namespace {
 template <>
 LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
     PrimNumToTensorScalarOp op, OpAdaptor adaptor,

@@ -592,11 +561,9 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
   rewriter.replaceOp(op, mhloTensor);
   return success();
 }
-} // namespace

 // AtenContiguousOp
 // Ref: TosaToTosa.cpp for implementation details
-namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
     AtenContiguousOp op, OpAdaptor adaptor,

@@ -614,11 +581,9 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
   return success();
 }

-} // namespace

 // AtenReluOp
 // Relu(x) = Max(0, x)
-namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
     AtenReluOp op, OpAdaptor adaptor,

@@ -641,11 +606,9 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
   return success();
 }

-} // namespace

 // Convert a Aten::GELU to HLO
 // Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
-namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
     AtenGeluOp op, OpAdaptor adaptor,

@@ -668,10 +631,8 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
   rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
   return success();
 }
-} // namespace

 // AtenErfOp
-namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
     AtenErfOp op, OpAdaptor adaptor,

@@ -686,10 +647,8 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
   return success();
 }

-} // namespace

 // AtenBatchNormOp
-namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
     AtenBatchNormOp op, OpAdaptor adaptor,
@@ -716,12 +675,12 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(

   Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1);

-#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+  if (options.dimSizeIndexBits == 32) {
     auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>(
         op->getLoc(), rewriter.getI64Type(), channelDim);
     channelDim = rewriter.create<arith::TruncIOp>(
         op->getLoc(), rewriter.getI32Type(), channelDimI64);
-#endif
+  }

   Value channelShape = rewriter.create<tensor::FromElementsOp>(
       op->getLoc(), ValueRange{channelDim});
@@ -806,10 +765,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
     }
   }

-} // namespace

 // AtenNativeLayerNormOp
-namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
     AtenNativeLayerNormOp op, OpAdaptor adaptor,

@@ -949,10 +906,8 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
   return success();
 }

-} // namespace

 // AtenCatOp
-namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
     AtenCatOp op, OpAdaptor adaptor,

@@ -983,10 +938,8 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
       op, outType, ValueRange(builtinTensors), posDim);
   return success();
 }
-} // namespace

 // AtenNumelOp
-namespace {
 template <>
 LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
     AtenNumelOp op,
@@ -996,7 +949,7 @@ LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
   auto selfTy = self.getType().dyn_cast<RankedTensorType>();
   size_t rank = selfTy.getRank();

-  Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
+  Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
   auto loc = op->getLoc();
   Value numel =
       rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(intType, 1));
@@ -1015,26 +968,18 @@ LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
   }
   return success();
 }
-} // namespace

 void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target) {
+    ConversionTarget &target, const TorchToMhloOptions &options) {
   MLIRContext *context = patterns.getContext();

   target.addIllegalOp<AtenTransposeIntOp>();
   patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);

-  target.addIllegalOp<AtenBroadcastToOp>();
-  patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
-
-  target.addIllegalOp<AtenPermuteOp>();
-  patterns.add<ConvertAtenPermuteOp>(typeConverter, context);
-
 #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \
   target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, \
-                                                         context);
+  patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, MhloOp>>(typeConverter, context)
   INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp);
   INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp);
   INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp);

@@ -1045,14 +990,14 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
 #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
   target.addIllegalOp<AtenOp>(); \
   patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
-                                                           context);
+                                                           context)
   INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
   INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
 #undef INSERT_CONSTANT_FILL_PATTERN

 #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
   target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context);
+  patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context)
   INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp);
   INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
   INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);

@@ -1062,7 +1007,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(

 #define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
   target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context);
+  patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context)
   INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
   INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
   INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);

@@ -1072,7 +1017,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(

 #define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \
   target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context);
+  patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context)

   INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);

@@ -1086,7 +1031,11 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(

 #define INSERT_ATENOP_PATTERN(AtenOp) \
   target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
+  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)

+  INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
+  INSERT_ATENOP_PATTERN(AtenPermuteOp);
+
   INSERT_ATENOP_PATTERN(AtenTanhOp);
   INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
   INSERT_ATENOP_PATTERN(AtenReciprocalOp);
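The macros above also lose their trailing semicolon; this is presumably part of the "dismiss compilation warnings" item, since the call sites already end in ';' and the old expansion therefore produced an empty extra statement. A standalone reproduction with hypothetical macro names:

    void handle(int) {}

    #define OLD_INSERT(x) handle(x);  // expansion already ends in ';'
    #define NEW_INSERT(x) handle(x)   // caller supplies the ';'

    void registerAll() {
      OLD_INSERT(1);  // expands to handle(1);; -> empty-statement warning (e.g. -Wextra-semi-stmt)
      NEW_INSERT(2);  // expands to handle(2);
    }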
@@ -23,18 +23,14 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
+using namespace mlir::torch::torch_to_mhlo;

-#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
-static constexpr size_t kMhloDimSizeBits = 32;
-#else
-static constexpr size_t kMhloDimSizeBits = 64;
-#endif
-
 namespace {
 Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
-                                  Value input, Value indices, int64_t axis) {
+                                  Value input, Value indices, int64_t axis,
+                                  size_t dimSizeIndexBits) {
   auto loc = op->getLoc();
-  Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
   Value one = rewriter.create<arith::ConstantOp>(
       loc, rewriter.getIntegerAttr(intType, 1));

@@ -98,16 +94,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
                          sliceSizesTensor, dimsAttr)
       .getResult();
 }
-
-template <typename AtenOpT>
-class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
-public:
-  using OpConversionPattern<AtenOpT>::OpConversionPattern;
-  using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult
-  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
+} // namespace

 // Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
 // padding_idx (int, optional)

@@ -149,8 +136,8 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "sparse gradients is currently not supported");

-  Value output =
-      gatherTensorAlongSingleAxis(rewriter, op, weight, adaptor.indices(), 0);
+  Value output = gatherTensorAlongSingleAxis(
+      rewriter, op, weight, adaptor.indices(), 0, options.dimSizeIndexBits);
   rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
       op, getTypeConverter()->convertType(op.getType()), output);

@@ -170,24 +157,23 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "only constant dim is currently supported");

-  Value output =
-      gatherTensorAlongSingleAxis(rewriter, op, self, adaptor.index(), dim);
+  Value output = gatherTensorAlongSingleAxis(
+      rewriter, op, self, adaptor.index(), dim, options.dimSizeIndexBits);

   rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
       op, getTypeConverter()->convertType(op.getType()), output);

   return success();
 }
-} // namespace

 void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target) {
+    ConversionTarget &target, const TorchToMhloOptions &options) {
   MLIRContext *context = patterns.getContext();

 #define INSERT_ATENOP_PATTERN(AtenOp) \
   target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
+  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
   INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
   INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
 #undef INSERT_ATENOP_PATTERN
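gatherTensorAlongSingleAxis no longer reads a file-local kMhloDimSizeBits constant; the integer width used for its shape arithmetic now travels with the call. Inside a pattern the call shape therefore becomes the following fragment, which simply mirrors the hunks above (options comes from the ConvertAtenOp base class):

    mlir::Value output = gatherTensorAlongSingleAxis(
        rewriter, op, weight, adaptor.indices(), /*axis=*/0,
        options.dimSizeIndexBits);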
@@ -25,6 +25,7 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
+using namespace mlir::torch::torch_to_mhlo;

 namespace {
 Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,

@@ -71,7 +72,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
 }

 void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
-                     Value &inpRhs, int64_t leadingRank) {
+                     Value &inpRhs, int64_t leadingRank,
+                     size_t dimSizeIndexBits) {
   Value lhs = inpLhs;
   Value rhs = inpRhs;
   auto lhsRankTy = inpLhs.getType().dyn_cast<RankedTensorType>();

@@ -92,9 +94,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
     std::vector<int64_t> newShape(rhsShape.begin(),
                                   rhsShape.begin() + leadingRank);
     newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end());
-    auto newDimSizes =
-        *mhlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims);
-    auto lhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, lhs);
+    auto newDimSizes = *mhlo::getDimSizesOfTensor(
+        rewriter, op, rhs, leadingDims, dimSizeIndexBits);
+    auto lhsDimSizes =
+        *mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
     newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(),
                        lhsDimSizes.end());
     lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes,

@@ -103,9 +106,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
     std::vector<int64_t> newShape(lhsShape.begin(),
                                   lhsShape.begin() + leadingRank);
     newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end());
-    auto newDimSizes =
-        *mhlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims);
-    auto rhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, rhs);
+    auto newDimSizes = *mhlo::getDimSizesOfTensor(
+        rewriter, op, lhs, leadingDims, dimSizeIndexBits);
+    auto rhsDimSizes =
+        *mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
     newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(),
                        rhsDimSizes.end());
     rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes,

@@ -122,9 +126,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
 // implement their specialized input processing (e.g transpose), and output
 // processing, e.g. GEMM or fully connected bias handling.
 template <typename AtenOpT>
-class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
+class ConvertAtenMatmulBaseOp : public ConvertAtenOp<AtenOpT> {
 public:
-  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
   using OpAdaptor = typename AtenOpT::Adaptor;
   // Each variant must implement corresponding parameter parsing options.
   // Maintain separate input read functions for each variant because it is not
@@ -159,20 +163,24 @@ public:
       return success();
     }

+    const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
     int64_t nBatchDims;
     if (rhsRank <= 2) {
       auto leadingRank = lhsRank - 2;
-      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
+      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
+                      options.dimSizeIndexBits);
       nBatchDims = leadingRank;
     } else if (lhsRank <= 2) {
       auto leadingRank = rhsRank - 2;
-      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
+      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
+                      options.dimSizeIndexBits);
       nBatchDims = leadingRank;
     } else {
       assert(rhsRank > 2 && lhsRank > 2);
       auto leadingRank = std::max(lhsRank - rhsRank, rhsRank - lhsRank);
       nBatchDims = std::max(lhsRank - 2, rhsRank - 2);
-      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
+      getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
+                      options.dimSizeIndexBits);
     }
     auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
     auto lhsContractingDim = nBatchDims + 1;

@@ -187,7 +195,7 @@ public:
         /*rhsBatchingDimensions=*/batchDims,
         /*lhsContractingDimensions=*/{lhsContractingDim},
         /*rhsContractingDimensions=*/{rhsContractingDim});
-    auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
+    auto resultTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
                         ->convertType(op.getType())
                         .template cast<RankedTensorType>();

@@ -215,7 +223,7 @@ public:

     rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
         op,
-        OpConversionPattern<AtenOpT>::getTypeConverter()
+        ConvertAtenOp<AtenOpT>::getTypeConverter()
             ->convertType(op.getType())
             .template cast<RankedTensorType>(),
         output);

@@ -340,7 +348,10 @@ public:
     auto rhsTy = rhs.getType().cast<RankedTensorType>();
     auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
                                 rhsTy.getRank() - lhsTy.getRank());
-    getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank);
+
+    const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
+    getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
+                    options.dimSizeIndexBits);
     auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
     auto nBatchDims = resultRank - 2;
     auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));

@@ -356,8 +367,7 @@ public:
         /*rhsContractingDimensions=*/{rhsContractingDim});

     auto resultTy =
-        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-            op.getType());
+        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());

     Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
         op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr);
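Because ConvertAtenMatmulBaseOp is a template whose base class depends on AtenOpT, members inherited from ConvertAtenOp are not found by unqualified name lookup; that is why the patch writes ConvertAtenOp<AtenOpT>::getOptions() (this->getOptions() would also work). A minimal illustration of the rule:

    template <typename T> struct Base {
      int getOptions() const { return 0; }
    };

    template <typename T> struct Derived : Base<T> {
      // 'getOptions()' alone would not compile: the base depends on T.
      int use() const { return Base<T>::getOptions(); }  // or this->getOptions()
    };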
@@ -377,12 +387,9 @@ public:
   }
 };

-} // namespace
-
-namespace {
-class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
+class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
 public:
-  using OpConversionPattern<AtenConvolutionOp>::OpConversionPattern;
+  using ConvertAtenOp<AtenConvolutionOp>::ConvertAtenOp;
   using OpAdaptor = typename AtenConvolutionOp::Adaptor;

   Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op,

@@ -390,8 +397,9 @@ public:
     auto weightTy = weight.getType().cast<RankedTensorType>();
     auto weightElemTy = weightTy.getElementType();
     auto rank = weightTy.getRank();
-    SmallVector<Value> weightShapeVec =
-        *mhlo::getDimSizesOfTensor(rewriter, op, weight);
+    const auto &options = getOptions();
+    SmallVector<Value> weightShapeVec = *mhlo::getDimSizesOfTensor(
+        rewriter, op, weight, options.dimSizeIndexBits);
     auto weightShape = weightTy.getShape();
     SmallVector<int64_t> weightShapeInt(rank);
     std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin());

@@ -601,9 +609,8 @@ public:
     return mhloConvOp.getResult();
   }

-  LogicalResult
-  matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
     Value input = adaptor.input();
     Value weight = adaptor.weight();

@@ -714,7 +721,10 @@ public:
       // Reshape and promote bias
       auto inputUnsqzDims =
           llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
-      bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
+
+      const auto &options = getOptions();
+      bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
+                                    options.dimSizeIndexBits);
       bias = mhlo::promoteType(rewriter, bias, outTy);

       DenseIntElementsAttr bcastDimensions;
@@ -727,31 +737,31 @@ public:

 void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target) {
+    ConversionTarget &target, const TorchToMhloOptions &options) {
   MLIRContext *context = patterns.getContext();

 #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \
   target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
+  patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context, options)
   INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
 #undef INSERT_MATMUL_ATEMOP_PATTERN

 #define INSERT_MM_ATENOP_PATTERN(AtenOp) \
   target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
+  patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context, options)
   INSERT_MM_ATENOP_PATTERN(AtenMmOp);
   INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
 #undef INSERT_MM_ATEMOP_PATTERN

 #define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \
   target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
+  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context, options)
   INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
 #undef INSERT_LINEAR_ATEMOP_PATTERN

 #define INSERT_CONVOLUTION_ATENOP_PATTERN(AtenOp) \
   target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
+  patterns.add<ConvertAtenConvolutionOp>(typeConverter, context, options)
   INSERT_CONVOLUTION_ATENOP_PATTERN(AtenConvolutionOp);
 #undef INSERT_CONVOLUTION_ATENOP_PATTERN
 }
@@ -259,9 +259,10 @@ SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
   return posDims;
 }

-FailureOr<SmallVector<Value, 4>>
-getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
-                    ArrayRef<int64_t> inpDims) {
+FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
+                                                     Operation *op, Value value,
+                                                     ArrayRef<int64_t> inpDims,
+                                                     size_t dimSizeIndexBits) {
   auto valueTy = value.getType().dyn_cast<RankedTensorType>();
   if (!valueTy) {
     return rewriter.notifyMatchFailure(

@@ -276,14 +277,15 @@ getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
   auto loc = op->getLoc();
   for (auto d : dims) {
     dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
-        loc, rewriter.getIntegerType(kMhloDimSizeBits),
+        loc, rewriter.getIntegerType(dimSizeIndexBits),
         rewriter.create<tensor::DimOp>(loc, value, d)));
   }
   return dimSizes;
 }

-FailureOr<SmallVector<Value, 4>>
-getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
+FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
+                                                     Operation *op, Value value,
+                                                     size_t dimSizeIndexBits) {
   auto valueTy = value.getType().dyn_cast<RankedTensorType>();
   if (!valueTy) {
     return rewriter.notifyMatchFailure(

@@ -294,12 +296,12 @@ getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
   // Get int vector [0, 1, ..., rank-1]
   std::vector<int64_t> dims(rank);
   std::iota(dims.begin(), dims.end(), 0);
-  return getDimSizesOfTensor(rewriter, op, value, dims);
+  return getDimSizesOfTensor(rewriter, op, value, dims, dimSizeIndexBits);
 }

 FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
-                                 Value tensor,
-                                 ArrayRef<int64_t> inputUnsqzDims) {
+                                 Value tensor, ArrayRef<int64_t> inputUnsqzDims,
+                                 size_t dimSizeIndexBits) {
   // Returns a new tensor with dims of size 1 inserted at the specified
   // position.
   //

@@ -307,7 +309,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
   // tensor) are specified with unsqzDims. Indices must be in-order, and in
   // range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1,
   // 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not.
-  auto dimSizesInfo = getDimSizesOfTensor(rewriter, op, tensor);
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
   if (failed(dimSizesInfo))
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");

@@ -324,7 +327,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
   auto loc = op->getLoc();
   auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
   auto oldShape = rankTy.getShape();
-  Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
   auto one = rewriter.create<arith::ConstantOp>(
       loc, rewriter.getIntegerAttr(intType, 1));
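After this change the shape utilities take the integer width explicitly instead of reading the deleted kMhloDimSizeBits constant. A fragment showing the updated call shapes from a pattern that has a TorchToMhloOptions in scope (illustrative; it simply mirrors the new signatures):

    auto dimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, tensor,
                                                  options.dimSizeIndexBits);
    if (failed(dimSizesInfo))
      return rewriter.notifyMatchFailure(op, "failed to get dimension sizes");

    auto unsqueezed = mhlo::unsqueezeTensor(rewriter, op, tensor,
                                            /*inputUnsqzDims=*/{0},
                                            options.dimSizeIndexBits);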
@@ -19,11 +19,6 @@

 namespace mlir {
 namespace mhlo {
-#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
-static constexpr size_t kMhloDimSizeBits = 32;
-#else
-static constexpr size_t kMhloDimSizeBits = 64;
-#endif

 using mlir::ConversionPatternRewriter;

@@ -60,22 +55,23 @@ SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank);
 // Get the dimension sizes of the input tensor, given the dimension axes
 FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
                                                      Operation *op, Value value,
-                                                     ArrayRef<int64_t> inpDims);
+                                                     ArrayRef<int64_t> inpDims,
+                                                     size_t dimSizeIndexBits);

 // Get the dimension sizes of the input tensor
-FailureOr<SmallVector<Value, 4>>
-getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value);
+FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
+                                                     Operation *op, Value value,
+                                                     size_t dimSizeIndexBits);

 // Get a tensor that unsqueezed the specified dimensions of the input tensor
 FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
-                                 Value tensor,
-                                 ArrayRef<int64_t> inputUnsqzDims);
+                                 Value tensor, ArrayRef<int64_t> inputUnsqzDims,
+                                 size_t dimSizeIndexBits);

 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType);
 } // namespace mhlo
 } // namespace mlir

 #endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
@ -28,6 +28,7 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
using namespace mlir::torch::torch_to_mhlo;
|
||||||
|
|
||||||
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
|
@ -72,22 +73,9 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
|
||||||
template <typename AtenOpT>
|
|
||||||
class ConvertAtenPoolingOp : public OpConversionPattern<AtenOpT> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
||||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const override;
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// AtenMaxPool2dOp
|
// AtenMaxPool2dOp
|
||||||
namespace {
|
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
||||||
AtenMaxPool2dOp op, OpAdaptor adaptor,
|
AtenMaxPool2dOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.self();
|
Value input = adaptor.self();
|
||||||
|
@ -186,12 +174,10 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dOp>::matchAndRewrite(
|
||||||
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// AtenMaxPool2dWithIndicesOp
|
// AtenMaxPool2dWithIndicesOp
|
||||||
namespace {
|
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||||
AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
|
AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.self();
|
Value input = adaptor.self();
|
||||||
|
@ -269,7 +255,9 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||||
rewriter.getI64Type()),
|
rewriter.getI64Type()),
|
||||||
mhloPadding);
|
mhloPadding);
|
||||||
|
|
||||||
auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
|
const auto &options = getOptions();
|
||||||
|
auto inputShapeInfo =
|
||||||
|
mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||||
if (failed(inputShapeInfo)) {
|
if (failed(inputShapeInfo)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "failed to get dimension sizes of the input");
|
op, "failed to get dimension sizes of the input");
|
||||||
|
@ -379,12 +367,10 @@ LogicalResult ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||||
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
rewriter.replaceOp(op, reduceWindowOp.getResults());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// AtenAvgPool2dOp
|
// AtenAvgPool2dOp
|
||||||
namespace {
|
|
||||||
template <>
|
template <>
|
||||||
LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenAvgPool2dOp>::matchAndRewrite(
|
||||||
AtenAvgPool2dOp op, OpAdaptor adaptor,
|
AtenAvgPool2dOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.self();
|
Value input = adaptor.self();
|
||||||
|
@ -502,7 +488,9 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
|
||||||
Value windowSizeConst =
|
Value windowSizeConst =
|
||||||
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
|
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
|
||||||
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
|
windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy);
|
||||||
auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input);
|
const auto &options = getOptions();
|
||||||
|
auto inputShapeVec =
|
||||||
|
*mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
||||||
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||||
op->getLoc(), inputShapeVec);
|
op->getLoc(), inputShapeVec);
|
||||||
|
|
||||||
|
@ -540,17 +528,15 @@ LogicalResult ConvertAtenPoolingOp<AtenAvgPool2dOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
|
void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target) {
|
ConversionTarget &target, const TorchToMhloOptions &options) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
target.addIllegalOp<AtenMaxPool2dOp>();
|
target.addIllegalOp<AtenMaxPool2dOp>();
|
||||||
patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dOp>>(typeConverter, context);
|
patterns.add<ConvertAtenOp<AtenMaxPool2dOp>>(typeConverter, context, options);
|
||||||
target.addIllegalOp<AtenAvgPool2dOp>();
|
target.addIllegalOp<AtenAvgPool2dOp>();
|
||||||
patterns.add<ConvertAtenPoolingOp<AtenAvgPool2dOp>>(typeConverter, context);
|
patterns.add<ConvertAtenOp<AtenAvgPool2dOp>>(typeConverter, context, options);
|
||||||
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
|
target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
|
||||||
patterns.add<ConvertAtenPoolingOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
|
patterns.add<ConvertAtenOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
|
||||||
context);
|
context, options);
|
||||||
}
|
}
|
||||||
|
|
|
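Note: every pooling pattern above now threads the configured index width into mhlo::getDimSizesOfTensor instead of relying on a compile-time constant. The helper's own header is not part of this commit view, so the declaration below is only a sketch inferred from the call sites in this diff; the namespace placement, header includes, and parameter names are assumptions.

// Sketch only, not the actual header of the helper. Each returned Value is
// expected to be a scalar of integer type iN, where N == dimSizeIndexBits
// (64 by default, 32 when the pass is built with enableI32Index).
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir {
namespace mhlo {

FailureOr<SmallVector<Value, 4>>
getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
                    size_t dimSizeIndexBits);

} // namespace mhlo
} // namespace mlir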
@@ -16,25 +16,56 @@ namespace mlir {
 namespace torch {
 namespace torch_to_mhlo {
 
+struct TorchToMhloOptions {
+  bool enableStaticShape = false;
+  size_t dimSizeIndexBits = 64;
+};
+
+template <typename AtenOpT>
+class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context,
+                const TorchToMhloOptions &options)
+      : OpConversionPattern<AtenOpT>(typeConverter, context) {
+    this->options = options;
+  }
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return rewriter.notifyMatchFailure(op, "haven't been implemented");
+  }
+  const TorchToMhloOptions &getOptions() const { return options; }
+
+private:
+  TorchToMhloOptions options;
+};
+
 void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
-                                        ConversionTarget &target);
+                                        ConversionTarget &target,
+                                        const TorchToMhloOptions &options);
 void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
                                            RewritePatternSet &patterns,
-                                           ConversionTarget &target);
+                                           ConversionTarget &target,
+                                           const TorchToMhloOptions &options);
 void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
                                          RewritePatternSet &patterns,
-                                         ConversionTarget &target);
+                                         ConversionTarget &target,
+                                         const TorchToMhloOptions &options);
 void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter,
                                             RewritePatternSet &patterns,
-                                            ConversionTarget &target);
+                                            ConversionTarget &target,
+                                            const TorchToMhloOptions &options);
 void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter,
                                          RewritePatternSet &patterns,
-                                         ConversionTarget &target);
+                                         ConversionTarget &target,
+                                         const TorchToMhloOptions &options);
 
 void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter,
                                           RewritePatternSet &patterns,
-                                          ConversionTarget &target);
+                                          ConversionTarget &target,
+                                          const TorchToMhloOptions &options);
 
 } // namespace torch_to_mhlo
 } // namespace torch
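To see how this interface is meant to be consumed, here is a hedged sketch of a conversion written against it. AtenFooOp and populateFooOpPatternsAndLegality are hypothetical placeholders, and the sketch assumes the same includes and using-directives as the existing TorchToMhlo conversion files; the options-carrying constructor, getOptions(), and the (typeConverter, context, options) registration are the pieces this patch introduces.

// Hypothetical example: AtenFooOp does not exist in the tree; it stands in
// for any op lowering that needs the per-pass options.
template <>
LogicalResult ConvertAtenOp<AtenFooOp>::matchAndRewrite(
    AtenFooOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  const TorchToMhloOptions &options = getOptions();
  // Dynamic dimension sizes are produced with the configured index width
  // (i64 by default, i32 when the pass is created with enableI32Index).
  auto dimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, adaptor.self(),
                                                options.dimSizeIndexBits);
  if (failed(dimSizesInfo))
    return rewriter.notifyMatchFailure(
        op, "failed to get dimension sizes of the input");
  // ... build the mhlo ops for this lowering ...
  return success();
}

void mlir::torch::torch_to_mhlo::populateFooOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const TorchToMhloOptions &options) {
  MLIRContext *context = patterns.getContext();
  target.addIllegalOp<AtenFooOp>();
  patterns.add<ConvertAtenOp<AtenFooOp>>(typeConverter, context, options);
}

The default matchAndRewrite in ConvertAtenOp only reports a match failure, so an op registered without a specialization is rejected rather than silently miscompiled.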
@@ -25,6 +25,7 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
+using namespace mlir::torch::torch_to_mhlo;
 
 static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                                            PatternRewriter &rewriter) {
@@ -72,7 +73,8 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
 // Util for converting AtenArgmaxOp and AtenMaxDimOp
 static llvm::Optional<ValueRange>
 getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
-            ArrayRef<Value> inputShapeVec, int64_t dim) {
+            ArrayRef<Value> inputShapeVec, int64_t dim,
+            size_t dimSizeIndexBits) {
   auto inputTy = input.getType().template cast<RankedTensorType>();
   if (!inputTy) {
     return llvm::None;
@@ -86,7 +88,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
   Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter);
   if (!initValue) return llvm::None;
   Value initIndex;
-  if (mlir::mhlo::kMhloDimSizeBits == 32) {
+  if (dimSizeIndexBits == 32) {
     initIndex = mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
   } else {
     initIndex = mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
@@ -98,7 +100,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
   auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
       op->getLoc(), inputShapeVec);
   auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
-      op->getLoc(), RankedTensorType::get(inputShape, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits)),
+      op->getLoc(),
+      RankedTensorType::get(inputShape,
+                            rewriter.getIntegerType(dimSizeIndexBits)),
       inputShapeTensor, static_cast<uint64_t>(dim));
 
   auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
@@ -114,7 +118,8 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
   // Add block arguments
   auto blockValArgumentType =
       RankedTensorType::get({}, inputTy.getElementType());
-  auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits));
+  auto blockIdxArgumentType =
+      RankedTensorType::get({}, rewriter.getIntegerType(dimSizeIndexBits));
   auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
   block.addArgument(blockValArgumentType, op->getLoc());
   block.addArgument(blockIdxArgumentType, op->getLoc());
@@ -171,9 +176,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
 
 namespace {
 template <typename AtenOpT>
-class ConvertAtenReductionOp : public OpConversionPattern<AtenOpT> {
+class ConvertAtenReductionOp : public ConvertAtenOp<AtenOpT> {
 public:
-  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
   using OpAdaptor = typename AtenOpT::Adaptor;
   LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
@@ -220,21 +225,24 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
   }
 
-  auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
+  const auto &options = getOptions();
+  auto inputShapeInfo =
+      mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
   if (failed(inputShapeInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");
   }
   auto inputShapeVec = *inputShapeInfo;
-  auto mhloReduceResults =
-      getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
+  auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
+                                       options.dimSizeIndexBits)
+                               .value();
 
   if (keepDim) {
     auto outShapeVec = inputShapeVec;
-
     outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(), rewriter.getIntegerAttr(
-            rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits), 1));
+        op->getLoc(),
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
 
     auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
         op->getLoc(), outShapeVec);
@@ -297,20 +305,24 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
   }
 
-  auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
+  const auto &options = getOptions();
+  auto inputShapeInfo =
+      mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
   if (failed(inputShapeInfo)) {
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");
   }
   auto inputShapeVec = *inputShapeInfo;
-  auto mhloReduceResults =
-      getMaxInDim(rewriter, op, input, inputShapeVec, dim).value();
+  auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim,
+                                       options.dimSizeIndexBits)
+                               .value();
 
   if (keepDim) {
     auto outShapeVec = inputShapeVec;
     outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(), rewriter.getIntegerAttr(
-            rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1));
+        op->getLoc(),
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
     auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
         op->getLoc(), outShapeVec);
 
@@ -532,15 +544,18 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
   }
 
   if (keepDim) {
-    auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input);
+    const auto &options = getOptions();
+    auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input,
+                                                  options.dimSizeIndexBits);
     if (failed(outShapeInfo)) {
       return rewriter.notifyMatchFailure(
           op, "failed to get dimension sizes of the input");
     }
     auto outShapeVec = *outShapeInfo;
     auto one = rewriter.create<mlir::arith::ConstantOp>(
-        op->getLoc(), rewriter.getIntegerAttr(
-            rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1));
+        op->getLoc(),
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(options.dimSizeIndexBits), 1));
     for (int64_t i : dims) {
       outShapeVec[i] = one;
     }
@@ -558,11 +573,11 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
 
 void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target) {
+    ConversionTarget &target, const TorchToMhloOptions &options) {
   MLIRContext *context = patterns.getContext();
 #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
   target.addIllegalOp<AtenOp>();                 \
-  patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context);
+  patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
@@ -32,6 +32,12 @@ namespace {
 
 class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
 public:
+  ConvertTorchToMhlo() = default;
+  ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) {
+    this->enableStaticShape = enableStaticShape;
+    this->enableI32Index = enableI32Index;
+  }
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<chlo::ChloDialect>();
     registry.insert<mhlo::MhloDialect>();
@@ -51,18 +57,20 @@ public:
 
     RewritePatternSet patterns(context);
 
+    torch_to_mhlo::TorchToMhloOptions options{enableStaticShape,
+                                              enableI32Index ? 32u : 64u};
     torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns,
-                                                      target);
+                                                      target, options);
-    torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter,
-                                                         patterns, target);
+    torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
+        typeConverter, patterns, target, options);
     torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns,
-                                                       target);
+                                                       target, options);
-    torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter,
-                                                          patterns, target);
+    torch_to_mhlo::populateReductionOpPatternsAndLegality(
+        typeConverter, patterns, target, options);
     torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns,
-                                                       target);
+                                                       target, options);
     torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns,
-                                                        target);
+                                                        target, options);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
@@ -75,5 +83,12 @@ public:
 
 std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::torch::createConvertTorchToMhloPass() {
-  return std::make_unique<ConvertTorchToMhlo>();
+  return std::make_unique<ConvertTorchToMhlo>(false, false);
+}
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape,
+                                          bool enableI32Index) {
+  return std::make_unique<ConvertTorchToMhlo>(enableStaticShape,
+                                              enableI32Index);
 }
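A minimal sketch of how a client pipeline could pick up the new overload. The header path, the nested-pass placement, and the surrounding pipeline code are assumptions; only createConvertTorchToMhloPass(bool, bool) itself comes from this patch.

// Illustrative pipeline setup, not part of this patch.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" // path assumed

static void buildTorchToMhloPipeline(mlir::PassManager &pm,
                                     bool enableStaticShape,
                                     bool enableI32Index) {
  // The zero-argument overload now forwards (false, false), so default
  // behavior is unchanged; enableI32Index = true makes the patterns compute
  // dimension sizes in i32 instead of i64.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::torch::createConvertTorchToMhloPass(enableStaticShape,
                                                enableI32Index));
}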
@@ -28,6 +28,7 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 using namespace mlir::torch::TorchConversion;
+using namespace mlir::torch::torch_to_mhlo;
 
 namespace {
 // A dimension index from torch.dialect might outside the range [0, dimSize].
@@ -55,10 +56,11 @@ Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
 Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
                               Type outTy, Value input, Value startIndex,
                               Value endIndex, Value step, size_t dimIndex,
-                              ArrayRef<Value> dimSizes) {
+                              ArrayRef<Value> dimSizes,
+                              size_t dimSizeIndexBits) {
   auto loc = op->getLoc();
   // startIndex & endIndex has been normailized into range [0, dSize]
-  Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, rewriter.getIntegerAttr(intType, 0));
   Value one = rewriter.create<arith::ConstantOp>(
@@ -109,7 +111,8 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
                                 Type outTy, Value input,
                                 llvm::Optional<Value> startIndexOpt,
                                 llvm::Optional<Value> endIndexOpt,
-                                llvm::Optional<Value> stepOpt, int64_t dim) {
+                                llvm::Optional<Value> stepOpt, int64_t dim,
+                                size_t dimSizeIndexBits) {
   auto loc = op->getLoc();
   auto inputTy = input.getType().dyn_cast<RankedTensorType>();
   auto rank = inputTy.getRank();
@@ -133,77 +136,31 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
           : rewriter.create<arith::ConstantOp>(
                 loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
 
-#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
-  auto i32Type = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
-  normStartIndex =
-      rewriter.create<arith::TruncIOp>(loc, i32Type, normStartIndex);
-  normEndIndex = rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
-  step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
-#endif
+  if (dimSizeIndexBits == 32) {
+    Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+    normStartIndex =
+        rewriter.create<arith::TruncIOp>(loc, intType, normStartIndex);
+    normEndIndex = rewriter.create<arith::TruncIOp>(loc, intType, normEndIndex);
+    step = rewriter.create<arith::TruncIOp>(loc, intType, step);
+  }
   FailureOr<SmallVector<Value, 4>> dimSizesInfo =
-      mhlo::getDimSizesOfTensor(rewriter, op, input);
+      mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
   if (failed(dimSizesInfo))
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");
 
   auto dimSizes = *dimSizesInfo;
   return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex,
-                                 normEndIndex, step, dim, dimSizes);
-}
-
-template <typename AtenOpT>
-class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
-public:
-  using OpConversionPattern<AtenOpT>::OpConversionPattern;
-  using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult
-  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-
-template <>
-LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
-    AtenSliceTensorOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  auto self = adaptor.self();
-  auto selfTy = self.getType().template cast<RankedTensorType>();
-  if (!selfTy)
-    return op.emitError("only ranked tensor types are supported");
-  auto outTy =
-      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
-  int64_t dim;
-  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
-    return rewriter.notifyMatchFailure(
-        op, "only constant dim is currently supported");
-
-  auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
-    if (val.getType().isa<Torch::NoneType>()) {
-      return llvm::None;
-    } else {
-      return val;
-    }
-  };
-
-  llvm::Optional<Value> start = getOptionalVal(adaptor.start());
-  llvm::Optional<Value> end = getOptionalVal(adaptor.end());
-  llvm::Optional<Value> step = getOptionalVal(adaptor.step());
-
-  FailureOr<Value> sliceInfo =
-      getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim);
-  if (failed(sliceInfo))
-    return op.emitError("can not create a dynmaic slice");
-
-  auto slice = *sliceInfo;
-  rewriter.replaceOp(op, slice);
-  return success();
+                                 normEndIndex, step, dim, dimSizes,
+                                 dimSizeIndexBits);
 }
 
 // This defines a template to construct ops whose legalizations are
 // specialized.
 template <typename AtenOpT>
-class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
+class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
 public:
-  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
   using OpAdaptor = typename AtenOpT::Adaptor;
 
   LogicalResult
@@ -235,19 +192,19 @@ public:
       return dSize;
     });
 
-#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
-    // The i64 calculation is much slower than i32 on some devices, such as
-    // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
-    // unlikely to exceed the range of i32(4GiB)
-    std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
-      // dimSize: cast i64 -> i32
-      dSize =
-          rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), dSize);
-      return dSize;
-    });
-#endif
+    const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
+    Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
+    if (options.dimSizeIndexBits == 32) {
+      // The i64 calculation is much slower than i32 on some devices, such as
+      // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
+      // unlikely to exceed the range of i32(4GiB)
+      std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
+        // dimSize: cast i64 -> i32
+        dSize = rewriter.create<arith::TruncIOp>(loc, intType, dSize);
+        return dSize;
+      });
+    }
 
-    Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
     Value numel = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getIntegerAttr(intType, 1));
     for (auto d : dimSizes) {
@@ -293,6 +250,45 @@ bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
     SmallVector<Value, 4> &dimSizes) const {
   return getListConstructElements(adaptor.shape(), dimSizes);
 }
+} // namespace
+
+template <>
+LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
+    AtenSliceTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.self();
+  auto selfTy = self.getType().template cast<RankedTensorType>();
+  if (!selfTy)
+    return op.emitError("only ranked tensor types are supported");
+  auto outTy =
+      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+  int64_t dim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant dim is currently supported");
+
+  auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
+    if (val.getType().isa<Torch::NoneType>()) {
+      return llvm::None;
+    } else {
+      return val;
+    }
+  };
+
+  llvm::Optional<Value> start = getOptionalVal(adaptor.start());
+  llvm::Optional<Value> end = getOptionalVal(adaptor.end());
+  llvm::Optional<Value> step = getOptionalVal(adaptor.step());
+
+  FailureOr<Value> sliceInfo =
+      getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim,
+                      options.dimSizeIndexBits);
+  if (failed(sliceInfo))
+    return op.emitError("can not create a dynmaic slice");
+
+  auto slice = *sliceInfo;
+  rewriter.replaceOp(op, slice);
+  return success();
+}
 
 template <>
 LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
@@ -324,7 +320,8 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
     return success();
   }
 
-  auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
+  auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
+                                                   options.dimSizeIndexBits);
   if (failed(newDimSizesInfo))
     return rewriter.notifyMatchFailure(
         op, "failed to get dimension sizes of the input");
@@ -372,7 +369,8 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
         op, getTypeConverter()->convertType(op.getType()), self);
     return success();
   }
-  auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
+  auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims,
+                                                   options.dimSizeIndexBits);
   if (failed(newDimSizesInfo))
    return rewriter.notifyMatchFailure(
        op, "failed to get dimension sizes of the input");
@@ -397,8 +395,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
   if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
     return op->emitError("dim must be a Scalar constant");
 
-  auto unsqzTensorInfo =
-      mhlo::unsqueezeTensor(rewriter, op, adaptor.self(), {dim});
+  auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.self(),
+                                               {dim}, options.dimSizeIndexBits);
   if (failed(unsqzTensorInfo))
     return rewriter.notifyMatchFailure(op,
                                        "failed to create unsqueezed tensor");
@@ -406,16 +404,15 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
   rewriter.replaceOp(op, *unsqzTensorInfo);
   return success();
 }
-} // namespace
 
 void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target) {
+    ConversionTarget &target, const TorchToMhloOptions &options) {
   MLIRContext *context = patterns.getContext();
 
#define INSERT_ATENOP_PATTERN(AtenOp) \
   target.addIllegalOp<AtenOp>();      \
-  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
+  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
   INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
   INSERT_ATENOP_PATTERN(AtenSqueezeOp);
   INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
@@ -424,7 +421,7 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
 
 #define INSERT_VIEW_OP_PATTERN(AtenOp) \
   target.addIllegalOp<AtenOp>();       \
-  patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context);
+  patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
   INSERT_VIEW_OP_PATTERN(AtenViewOp);
   INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
 #undef INSERT_VIEW_OP_PATTERN
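For users who assemble the conversion patterns directly rather than running the pass, the options struct can be built by hand. One detail worth calling out: the pass exposes a boolean (enableI32Index), while TorchToMhloOptions stores a bit width, which the pass derives as enableI32Index ? 32u : 64u. A short sketch, with everything outside the options handling assumed rather than taken from this patch:

// Sketch: feed the TorchToMhlo patterns into a custom conversion setup.
// TypeConverter/ConversionTarget preparation is elided and assumed to mirror
// what ConvertTorchToMhlo::runOnOperation does.
void addTorchToMhloPatterns(TypeConverter &typeConverter,
                            RewritePatternSet &patterns,
                            ConversionTarget &target, bool enableI32Index) {
  torch_to_mhlo::TorchToMhloOptions options;
  options.enableStaticShape = false;                   // default behavior
  options.dimSizeIndexBits = enableI32Index ? 32 : 64; // a width, not a flag
  torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, patterns,
                                                       target, options);
  // The remaining populate*OpPatternsAndLegality entry points take the same
  // trailing options argument.
}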