[MHLO] Support for dynamic shape in basic op conversion by introducing CHLO dialect (#1123)

* [MHLO] Support for dynamic shape in basic op conversion by introducing CHLO dialect

Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>

* [MHLO] Support I32 as shape tensor dtype

* [NFC] Add a 'TODO' annotation
武家伟 2022-08-02 12:53:24 +08:00 committed by GitHub
parent 3772e0bd91
commit 76c976682c
8 changed files with 815 additions and 893 deletions
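For context, the effect of this patch on the emitted IR: the add/sub/mul/div and compare patterns now produce CHLO broadcasting ops (with a chlo.broadcast_mul folding in a non-unit alpha) instead of mhlo ops that required a static mhlo.broadcast_in_dim of the operands, so dynamically shaped operands no longer need their shapes resolved at conversion time. A rough sketch of the new lowering for torch.aten.add.Tensor; the SSA names and tensor shapes below are illustrative, not taken from the patch:

    // Illustrative only: what ConvertAtenAddSubOp now emits for dynamic shapes.
    func.func @add_alpha(%lhs: tensor<?x?xf32>, %rhs: tensor<?xf32>,
                         %alpha: tensor<f32>) -> tensor<?x?xf32> {
      // A non-unit alpha is multiplied into the rhs first ...
      %scaled = chlo.broadcast_mul %rhs, %alpha
          : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
      // ... then the broadcasting add is emitted directly.
      %sum = chlo.broadcast_add %lhs, %scaled
          : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?x?xf32>
      return %sum : tensor<?x?xf32>
    }

The CHLO ops carry implicit numpy-style broadcast semantics; they are expanded back into MHLO later by the chlo-legalize-to-hlo pass that this patch adds to the backend pipeline.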


@@ -12,7 +12,10 @@
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -51,7 +54,7 @@ public:
auto selfTy = self.getType().cast<TensorType>();
if (!selfTy)
return op.emitError("Only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in MHLO");
if (selfTy.getElementType().isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<MhloOpT>(
@@ -62,7 +65,7 @@ public:
return success();
} else {
return op.emitError(
"Only floating-point datatype legalization supported");
"only floating-point datatype legalization supported");
}
}
};
@@ -85,29 +88,29 @@ public:
.template dyn_cast<TensorType>();
if (!outType)
return op.emitError("Only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in MHLO");
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())
return op.emitError(
"Only floating-point or integer datatype legalization supported");
"only floating-point or integer datatype legalization supported");
// FIXME: Handle layout, device and pin_memory. Assume dtype has been
// processed to set output type correctly?
if (!op.layout().getType().template isa<Torch::NoneType>())
return op.emitError("Only default layout is supported");
return op.emitError("only default layout is supported");
bool pinMemory;
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) {
return op.emitError(
"Unsupported pin_memory, should be either None or false");
"unsupported pin_memory, should be either None or false");
}
SmallVector<int64_t> shape;
if (!matchPattern(op.size(), m_TorchConstantIntList(shape))) {
return op.emitError("Shape must be a list of Scalar constants");
return op.emitError("shape must be a list of Scalar constants");
}
int64_t size = 1;
@@ -128,7 +131,7 @@ public:
// These binary op legalizations are specific to add/sub which have an
// alpha multiplier.
namespace {
template <typename AtenOpT, typename MhloOpT>
template <typename AtenOpT, typename ChloOpT>
class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
@@ -137,12 +140,12 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.self();
TensorType lhsType = lhs.getType().dyn_cast<TensorType>();
RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
Value rhs = adaptor.other();
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsType)
return op.emitError("Only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in MHLO");
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
@@ -151,53 +154,44 @@ public:
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
return op.emitError(
"Only floating-point or integer datatype legalization supported");
"only floating-point or integer datatype legalization supported");
}
Value rhsAsTensor;
if (!rhsType) {
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(),
rhsAsTensor, outElemTy,
outType.getShape())))
return op.emitError("Currently only scalar constants are supported for "
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
outElemTy, {})))
return op.emitError("currently only scalar constants are supported for "
"conversion in MHLO operation");
}
Value lhsTensor = lhs;
Value rhsTensor = rhsType ? rhs : rhsAsTensor;
// Handle broadcasting. Since we have the output type already, here we
// just broodcast operands' shape to output shape.
lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType);
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType);
lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);
// Handle alpha.
Value multTensor;
if (skipMultiplyAlpha(op.alpha())) {
multTensor = rhsTensor;
} else {
Value alphaTensor;
if (!skipMultiplyAlpha(op.alpha())) {
Value alpha;
if (failed(mhlo::torchAlphaToMhloTensor(rewriter, op.getOperation(),
op.alpha(), alphaTensor,
outElemTy, outType.getShape(),
op.alpha(), alpha, outElemTy, {},
/*checkForUnity=*/false))) {
return op.emitError("Currently only scalar constants are supported for "
return op.emitError("currently only scalar constants are supported for "
"alpha in conversion to MHLO operation");
}
multTensor = rewriter.create<mhlo::MulOp>(op.getLoc(), outType, rhsTensor,
alphaTensor);
DenseIntElementsAttr bcastDimensions;
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
bcastDimensions);
}
rewriter.replaceOpWithNewOp<MhloOpT>(op, outType, lhsTensor, multTensor);
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
return success();
}
};
} // namespace
// Binary op legalizations for Mul variants.
// Binary op legalizations for Mul/Div variants.
namespace {
template <typename AtenOpT>
class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
template <typename AtenOpT, typename ChloOpT>
class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
@@ -210,7 +204,7 @@ public:
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
if (!lhsType)
return op.emitError("Only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in MHLO");
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
@@ -219,86 +213,23 @@ public:
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
return op.emitError(
"Only floating-point or integer datatype legalization supported");
"only floating-point or integer datatype legalization supported");
}
Value lhsTensor = lhs;
Value rhsTensor;
if (std::is_same<AtenOpT, AtenSquareOp>()) {
rhsTensor = lhs;
} else {
if (!rhsType) {
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(),
rhsTensor, outElemTy,
outType.getShape())))
return op.emitError(
"Currently only scalar constants are supported for "
"conversion in MHLO operation");
} else {
rhsTensor = rhs;
}
}
// Handle broadcasting. Since we have the output type already, here we
// just broodcast operands' shape to output shape.
lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType);
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType);
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, lhsTensor, rhsTensor);
return success();
}
};
} // namespace
// Binary op legalizations for Div variants.
namespace {
template <typename AtenOpT>
class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
if (!lhsTy)
return op.emitError("Only Tensor types supported.");
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
return op.emitError(
"Only floating-point or integer datatype legalization supported");
}
Value lhsTensor = lhs;
Value rhsTensor;
if (!rhsTy) {
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(),
rhsTensor, outElemTy,
outType.getShape())))
return op.emitError("Currently only scalar constants are supported for "
if (!rhsType) {
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
outElemTy, {})))
return op.emitError("currently only scalar constants are supported for "
"conversion in MHLO operation");
} else {
rhsTensor = rhs;
}
// Handle broadcasting. Since we have the output type already, here we
// just broodcast operands' shape to output shape.
lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType);
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType);
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outType, lhsTensor, rhsTensor);
DenseIntElementsAttr bcastDimensions;
lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
return success();
}
};
} // namespace
// Binary op legalizations for comparator ops.
@@ -318,7 +249,7 @@ public:
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsTy)
return op.emitError("Only Tensor types supported in MHLO");
return op.emitError("only Tensor types supported in MHLO");
RankedTensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
@@ -327,21 +258,19 @@ public:
Type lhsElemTy = lhsTy.getElementType();
if (!lhsElemTy.isIntOrFloat()) {
return op.emitError(
"Only floating-point or integer datatype legalization supported");
"only floating-point or integer datatype legalization supported");
}
Value rhsAsTensor;
if (!rhsTy) {
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(),
rhsAsTensor, lhsElemTy, {}))) {
return op.emitError("Currently only scalar constants are supported for "
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
lhsElemTy, {}))) {
return op.emitError("currently only scalar constants are supported for "
"conversion in MHLO operation");
}
}
Value lhsTensor = lhs;
Value rhsTensor = rhsTy ? rhs : rhsAsTensor;
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, lhsTy);
// TODO: what is the PyTorch default type promotion?
rhs = mhlo::promoteType(rewriter, rhs, lhsTy);
mhlo::ComparisonTypeAttr compareTypeAttr;
mhlo::ComparisonDirectionAttr compareDirectionAttr;
@@ -371,9 +300,9 @@ public:
compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
op->getContext(), mhlo::ComparisonDirection::NE);
}
rewriter.replaceOpWithNewOp<mhlo::CompareOp>(
op, outType, lhsTensor, rhsTensor, compareDirectionAttr,
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastCompareOp>(
op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr,
compareTypeAttr);
return success();
}
@@ -438,12 +367,63 @@ public:
matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value self = adaptor.self();
auto selfTy = self.getType().cast<RankedTensorType>();
auto outType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
rewriter.replaceOp(op, bcastOp);
#ifdef TORCH_MLIR_ENABLE_MHLO_STATIC_SHAPE
if (selfTy.hasStaticShape()) {
Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
rewriter.replaceOp(op, bcastOp);
return success();
}
#endif
SmallVector<Value> shape;
if (!(getListConstructElements(adaptor.size(), shape))) {
return op->emitError("desired shape must be a list of scalar");
}
SmallVector<Value> bcastShapeVec;
int64_t totalRank = shape.size();
int64_t selfRank = selfTy.getRank();
int64_t leadingRank = totalRank - selfRank;
for (int64_t i = 0; i < totalRank; ++i) {
Value dValue = shape[i];
Value newD;
int64_t dInt;
if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) {
return op->emitError("element of desired shape must be a scalar");
}
if (i >= leadingRank && dInt == -1) {
newD = rewriter.create<mlir::tensor::DimOp>(op->getLoc(), self,
i - leadingRank);
} else {
dValue = rewriter.create<torch::TorchConversion::ToI64Op>(op->getLoc(),
dValue);
newD = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getIndexType(), dValue);
}
bcastShapeVec.push_back(newD);
}
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
for (auto &dsize : bcastShapeVec) {
auto dsizeI64 = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getI64Type(), dsize);
dsize = rewriter.create<arith::TruncIOp>(op->getLoc(),
rewriter.getI32Type(), dsizeI64);
}
#endif
Value bcastShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), ValueRange{bcastShapeVec});
auto dimensionNumbers =
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
op, outType, self, bcastShapeTensor,
rewriter.getI64TensorAttr(dimensionNumbers));
return success();
}
};
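For reference, the dynamic branch above assembles the target shape as a tensor.from_elements of index values (truncated to i32 when TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 is set, which is what the "Support I32 as shape tensor dtype" commit note refers to) and feeds it to mhlo.dynamic_broadcast_in_dim. A hypothetical sketch for a rank-1 input broadcast to size [%n, -1]; in the real lowering %n would come from a !torch.int via torch_c.to_i64, here it is taken as an i64 argument for brevity:

    func.func @broadcast_to(%self: tensor<?xf32>, %n: i64) -> tensor<?x?xf32> {
      %c0 = arith.constant 0 : index
      // A trailing -1 reuses the operand's own runtime size.
      %d1 = tensor.dim %self, %c0 : tensor<?xf32>
      // Explicit sizes from the size list are converted to index.
      %d0 = arith.index_cast %n : i64 to index
      %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
      %out = "mhlo.dynamic_broadcast_in_dim"(%self, %shape)
          {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
          : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
      return %out : tensor<?x?xf32>
    }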
@@ -464,19 +444,19 @@ public:
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
if (!inType)
return op.emitError("Only ranked tensor types with static shapes are "
return op.emitError("only ranked tensor types with static shapes are "
"currently supported");
SmallVector<int64_t> permValues;
if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues)))
return rewriter.notifyMatchFailure(
op, "Only constant dimensions are currently supported");
op, "only constant dimensions are currently supported");
int64_t inRank = inType.getRank();
for (auto &d : permValues) {
d = toPositiveDim(d, inRank);
if (!isValidDim(d, inRank))
return op.emitError("Not all dims are valid");
return op.emitError("not all dims are valid");
}
DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
@@ -517,7 +497,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
return success();
} else {
return op.emitError(
"Only floating-point datatype legalization currently supported");
"only floating-point datatype legalization currently supported");
}
}
} // namespace
@@ -535,6 +515,7 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
// Tensors with integer types need to be converted to signless integer
// element type. All tensors with element types other than integer can reuse
// existing elements attribute.
// TODO: what about unsigned integer?
if (auto elements = op.valueAttr().dyn_cast<DenseIntElementsAttr>()) {
Type builtinTensorElemTy = resultType.getElementType();
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
@@ -566,13 +547,11 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
auto outTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
if (!inputTy.getElementType().isa<mlir::FloatType>()) {
return op.emitError("Only floating-point datatype legalization supported "
return op.emitError("only floating-point datatype legalization supported "
"for AtenReciprocalOp");
}
Value oneTensor =
mhlo::getConstTensor<float>(rewriter, op, {static_cast<float>(1.0)}, {})
.getValue();
oneTensor = mhlo::promoteAndBroadcast(rewriter, oneTensor, inputTy);
Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
return success();
}
@@ -593,7 +572,7 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.a(), mhloTensor,
outputElemType, outputShape,
false))) {
return op->emitError("Failed lowering PrimNumToTensorScalarOp to MHLO");
return op->emitError("failed lowering PrimNumToTensorScalarOp to MHLO");
}
rewriter.replaceOp(op, mhloTensor);
return success();
@@ -611,7 +590,7 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
// Not a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
if (!selfType)
return op.emitError("Only tensor types are currently supported");
return op.emitError("only tensor types are currently supported");
// FIXME: memory_format is not handled.
@@ -633,27 +612,17 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
auto lhsTy = lhs.getType().cast<RankedTensorType>();
auto lhsElemTy = lhsTy.getElementType();
int64_t lhsSize = 1;
for (auto &en : llvm::enumerate(lhsTy.getShape())) {
lhsSize *= en.value();
if (!lhsElemTy.isa<mlir::FloatType>()) {
return op->emitError("only float tensor in relu op is supported");
}
auto constTy = RankedTensorType::get(lhsTy.getShape(), lhsElemTy);
DenseElementsAttr constAttr;
if (lhsElemTy.isa<mlir::FloatType>()) {
std::vector<APFloat> constVec(
lhsSize,
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
/*negative=*/false));
constAttr = DenseElementsAttr::get(constTy, constVec);
} else if (lhsElemTy.isa<mlir::IntegerType>()) {
std::vector<APInt> constVec(
lhsSize, APInt::getZero(lhsElemTy.getIntOrFloatBitWidth()));
constAttr = DenseElementsAttr::get(constTy, constVec);
}
Value rhs =
rewriter.create<mhlo::ConstantOp>(op.getLoc(), constTy, constAttr);
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, rhs);
Value zeroTensor;
zeroTensor = chlo::getConstantLike(
rewriter, op->getLoc(),
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
false),
lhs);
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
return success();
}
@@ -666,72 +635,12 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
AtenErfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.self();
auto inputType = input.getType().cast<RankedTensorType>();
auto inputType = input.getType().cast<TensorType>();
if (!inputType.getElementType().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(op, "Only support float data type");
return rewriter.notifyMatchFailure(op, "only float tensor is supported");
}
auto outType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
// Using:
// https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with
// maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 =
// 0.000972, a4 = 0.078108.
// Erf = 1 - 1 / (1 + a1X + a2X^2 + a3X^3 + a4X^4)^4
auto loc = op->getLoc();
auto zeroConst =
mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).getValue();
auto zero = mhlo::promoteAndBroadcast(rewriter, zeroConst, outType);
auto oneConst =
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).getValue();
auto one = mhlo::promoteAndBroadcast(rewriter, oneConst, outType);
auto a1Const =
mhlo::getConstTensor<float>(rewriter, op, {0.278393}, {}).getValue();
auto a1 = mhlo::promoteAndBroadcast(rewriter, a1Const, outType);
auto a2Const =
mhlo::getConstTensor<float>(rewriter, op, {0.230389}, {}).getValue();
auto a2 = mhlo::promoteAndBroadcast(rewriter, a2Const, outType);
auto a3Const =
mhlo::getConstTensor<float>(rewriter, op, {0.000972}, {}).getValue();
auto a3 = mhlo::promoteAndBroadcast(rewriter, a3Const, outType);
auto a4Const =
mhlo::getConstTensor<float>(rewriter, op, {0.078108}, {}).getValue();
auto a4 = mhlo::promoteAndBroadcast(rewriter, a4Const, outType);
auto absX = rewriter.create<mhlo::AbsOp>(loc, outType, input);
auto a1X = rewriter.create<mhlo::MulOp>(loc, outType, a1, absX);
auto sum = rewriter.create<mhlo::AddOp>(loc, outType, a1X, one);
auto x2 = rewriter.create<mhlo::MulOp>(loc, outType, absX, absX);
auto a2X = rewriter.create<mhlo::MulOp>(loc, outType, a2, x2);
sum = rewriter.create<mhlo::AddOp>(loc, outType, sum, a2X);
auto x3 = rewriter.create<mhlo::MulOp>(loc, outType, x2, absX);
auto a3X = rewriter.create<mhlo::MulOp>(loc, outType, a3, x3);
sum = rewriter.create<mhlo::AddOp>(loc, outType, sum, a3X);
auto x4 = rewriter.create<mhlo::MulOp>(loc, outType, x3, absX);
auto a4X = rewriter.create<mhlo::MulOp>(loc, outType, a4, x4);
sum = rewriter.create<mhlo::AddOp>(loc, outType, sum, a4X);
auto rcprl = rewriter.create<mhlo::DivOp>(loc, outType, one, sum);
auto rcprl2 = rewriter.create<mhlo::MulOp>(loc, outType, rcprl, rcprl);
auto rcprl4 = rewriter.create<mhlo::MulOp>(loc, outType, rcprl2, rcprl2);
auto erf = rewriter.create<mhlo::SubtractOp>(loc, outType, one, rcprl4);
// Deal with negative x.
mhlo::ComparisonDirectionAttr compareDirectionAttr =
mhlo::ComparisonDirectionAttr::get(op->getContext(),
mhlo::ComparisonDirection::GE);
mhlo::ComparisonTypeAttr compareTypeAttr = mhlo::ComparisonTypeAttr::get(
op->getContext(), mhlo::ComparisonType::FLOAT);
auto geZero = rewriter.create<mhlo::CompareOp>(
loc, RankedTensorType::get(outType.getShape(), rewriter.getI1Type()),
input, zero, compareDirectionAttr, compareTypeAttr);
auto negaErf = rewriter.create<mhlo::NegOp>(loc, erf);
rewriter.replaceOpWithNewOp<mhlo::SelectOp>(op, outType, geZero, erf,
negaErf);
rewriter.replaceOpWithNewOp<chlo::ErfOp>(
op, getTypeConverter()->convertType(op.getType()), input);
return success();
}
@@ -754,59 +663,71 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
Value momentum = adaptor.momentum();
(void)momentum;
// init weight, bias, runningVar, runningMean if they are none
auto initNoneValue = [&](Value &input, bool zero) {
SmallVector<APFloat> constVec(inputTy.getShape()[1],
APFloat::getZero(inputTy.getElementType()
.cast<mlir::FloatType>()
.getFloatSemantics()));
if (!zero) {
for (auto &item : constVec) {
item = APFloat(inputTy.getElementType()
.cast<mlir::FloatType>()
.getFloatSemantics(),
1);
}
}
auto constType = RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType());
auto constAttr = DenseElementsAttr::get(constType, constVec);
input =
rewriter.create<mhlo::ConstantOp>(op.getLoc(), constType, constAttr);
};
if (inputTy.getRank() <= 2) {
return rewriter.notifyMatchFailure(op,
"input should have rank larger than 2");
}
if (!inputTy.getElementType().template isa<mlir::FloatType>()) {
return op.emitError("only input tensor of float type is supported");
}
auto inputElemTy = inputTy.getElementType().cast<mlir::FloatType>();
Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1);
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getI64Type(), channelDim);
channelDim = rewriter.create<arith::TruncIOp>(
op->getLoc(), rewriter.getI32Type(), channelDimI64);
#endif
Value channelShape = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ValueRange{channelDim});
if (failed(checkNotNone(rewriter, op, weight))) {
initNoneValue(weight, false);
weight = mhlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType()));
}
if (failed(checkNotNone(rewriter, op, bias))) {
initNoneValue(bias, true);
bias = mhlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType()));
}
if (failed(checkNotNone(rewriter, op, runningVar))) {
initNoneValue(runningVar, false);
runningVar = mhlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType()));
}
if (failed(checkNotNone(rewriter, op, runningMean))) {
initNoneValue(runningMean, true);
runningMean = mhlo::getConstantOfShape(
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
channelShape,
RankedTensorType::get({inputTy.getShape()[1]},
inputTy.getElementType()));
}
auto weightTy = weight.getType().cast<RankedTensorType>();
auto biasTy = bias.getType().cast<RankedTensorType>();
auto runningMeanTy = runningMean.getType().cast<RankedTensorType>();
auto runningVarTy = runningVar.getType().cast<RankedTensorType>();
if (inputTy.getRank() <= 2) {
return rewriter.notifyMatchFailure(op,
"input should have rank larger than 2");
}
if (weightTy.getRank() != 1 || biasTy.getRank() != 1 ||
runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) {
return rewriter.notifyMatchFailure(
op, "expect weight, bias, running_mean and running_var to be rank 1");
}
if (!inputTy.getElementType().template isa<mlir::FloatType>() ||
!weightTy.getElementType().template isa<mlir::FloatType>() ||
if (!weightTy.getElementType().template isa<mlir::FloatType>() ||
!biasTy.getElementType().template isa<mlir::FloatType>() ||
!runningMeanTy.getElementType().template isa<mlir::FloatType>() ||
!runningVarTy.getElementType().template isa<mlir::FloatType>()) {
return op.emitError(
"Only float element type is supported in MHLO BatchNormOp");
return op.emitError("only float weight/bias/runningMean/runningVar tensor "
"of float type is supported");
}
double eps = 0.0;
@@ -858,6 +779,10 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
Value weight = adaptor.weight();
Value bias = adaptor.bias();
if (!inputTy.hasStaticShape()) {
return op->emitError("dynamic shaped input is not supported");
}
SmallVector<int64_t> normalizedShape;
if (!matchPattern(op.normalized_shape(),
m_TorchConstantIntList(normalizedShape))) {
@@ -866,11 +791,12 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
}
double eps = 0;
if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) {
return rewriter.notifyMatchFailure(op, "non const float eps unsupported");
return rewriter.notifyMatchFailure(op,
"non const float eps is unsupported");
}
if (failed(checkNotNone(rewriter, op, weight)) ||
failed(checkNotNone(rewriter, op, bias))) {
return op->emitError("Unsupported None for weight or bias");
return op->emitError("none weight or bias is unsupported");
}
auto weightTy = weight.getType().cast<RankedTensorType>();
auto biasTy = bias.getType().cast<RankedTensorType>();
@@ -878,13 +804,13 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
if (!inputTy.getElementType().isa<mlir::FloatType>() ||
!biasTy.getElementType().isa<mlir::FloatType>() ||
!weightTy.getElementType().isa<mlir::FloatType>()) {
return op->emitError("For now, only float data type are supported");
return op->emitError("currently only float data type are supported");
}
int64_t normalizedShapeRank = normalizedShape.size();
if (weightTy.getRank() != normalizedShapeRank ||
biasTy.getRank() != normalizedShapeRank ||
inputRank < normalizedShapeRank || normalizedShapeRank < 1) {
return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
return rewriter.notifyMatchFailure(op, "input or weight or bias shape or"
"normalized shape not compatible");
}
for (int64_t i = 1; i <= normalizedShapeRank; i++) {
@@ -897,7 +823,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
}
}
// flatten dims to fit batch_norm operation.
// Flatten dims to fit batch_norm operation.
int64_t numFeatureDimSize = 1;
int64_t numEmbeddingDimSize = 1;
for (int64_t i = 0; i < inputRank - normalizedShapeRank; i++) {
@@ -915,14 +841,14 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
auto mhloBathNormOutMeanOrVarTy =
RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType());
// reshape input
// Reshape input
auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), mhloBatchNormOutTy, input,
mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
{static_cast<int64_t>(inputFlattenShape.size())})
.getValue());
// generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
// Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
SmallVector<APFloat> zeroConstVec(
numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
.cast<mlir::FloatType>()
@@ -946,7 +872,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset,
rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
// reshape back
// Reshape back
auto outputTy =
getTypeConverter()->convertType(op.getType(0)).cast<RankedTensorType>();
auto outputMeanOrVarTy =
@@ -1016,32 +942,28 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
#undef INSERT_CONSTANT_FILL_PATTERN
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, MhloOp) \
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAddSubOp<AtenOp, MhloOp>>(typeConverter, context);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, mhlo::AddOp);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, mhlo::AddOp);
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, mhlo::SubtractOp);
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, mhlo::SubtractOp);
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp);
#undef INSERT_BINARY_ADDSUB_PATTERN
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenMulOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp);
INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp);
#undef INSERT_BINARY_MUL_PATTERN
#define INSERT_BINARY_DIV_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
#undef INSERT_BINARY_DIV_PATTERN
patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context);
INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
#undef INSERT_BINARY_MULDIV_PATTERN
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context);
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
@@ -1067,4 +989,4 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
#undef INSERT_ATENOP_PATTERN
}


@@ -10,6 +10,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
DEPENDS
MhloDialect
ChloDialect
TorchMLIRConversionPassIncGen
LINK_COMPONENTS
@@ -19,6 +20,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
MLIRIR
MLIRPass
MhloDialect
ChloDialect
TorchMLIRTorchDialect
)


@@ -117,9 +117,9 @@ llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
}
template <>
llvm::Optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
Operation *op, ArrayRef<double> vec,
ArrayRef<int64_t> shape) {
llvm::Optional<Value>
getConstTensor<double>(PatternRewriter &rewriter, Operation *op,
ArrayRef<double> vec, ArrayRef<int64_t> shape) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
@@ -149,7 +149,6 @@ template llvm::Optional<Value> getConstTensor<int64_t>(PatternRewriter &,
ArrayRef<int64_t> vec,
ArrayRef<int64_t> shape);
template <typename T>
static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
const int64_t &intValue) {
@@ -166,20 +165,16 @@ static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
}
template <typename T>
Value getSplatConstTensor(ConversionPatternRewriter &rewriter,
Operation *op,
T val,
Type dtype,
llvm::ArrayRef<int64_t> dshape) {
auto const_type = RankedTensorType::get(
dshape, dtype);
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
auto const_type = RankedTensorType::get(dshape, dtype);
auto const_attr = SplatElementsAttr::get(const_type, val);
auto const_op =
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult();
}
// TODO: Support for variable scalar.
LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value torchScalarValue,
Value &mhloTensor, Type dtype,
@@ -198,9 +193,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
if (dtype.isa<mlir::FloatType>()) {
if (doBroadcast) {
mhloTensor = getSplatConstTensor<float>(rewriter, op,
(isFloat ? doubleValue : intValue),
dtype, dshape);
mhloTensor = getSplatConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dtype, dshape);
} else {
mhloTensor = mhlo::getConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dshape)
@@ -219,7 +213,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
: static_cast<int32_t>(intValue);
if (doBroadcast) {
mhloTensor = getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
mhloTensor =
getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
} else {
mhloTensor =
mhlo::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
@@ -231,7 +226,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
}
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
if (doBroadcast) {
mhloTensor = getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
mhloTensor =
getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
} else {
mhloTensor =
mhlo::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
@@ -243,7 +239,6 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
return success();
}
LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value alphaScalar,
Value &alphaTensor, Type dtype,
@@ -268,20 +263,33 @@ LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
return success();
}
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
Value input, TensorType outType) {
// Two tensors are “broadcastable” if the following rules hold:
// - Each tensor has at least one dimension.
// - When iterating over the dimension sizes, starting at the trailing dimension,
// the dimension sizes must either be equal, one of them is 1, or one of them
// does not exist.
Operation* op = input.getDefiningOp();
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
Operation *op = input.getDefiningOp();
TensorType in_type = input.getType().dyn_cast<TensorType>();
if (in_type.getElementType() != outType.getElementType()) {
TensorType promoted_type = in_type.cloneWith(in_type.getShape(), outType.getElementType());
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
TensorType promotedType =
in_type.cloneWith(in_type.getShape(), outType.getElementType());
return rewriter.create<mhlo::ConvertOp>(op->getLoc(), promotedType, input);
}
return input;
}
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
TensorType outType) {
// Two tensors are “broadcastable” if the following rules hold:
// - Each tensor has at least one dimension.
// - When iterating over the dimension sizes, starting at the trailing
// dimension, the dimension sizes must either be equal, one of them is 1, or
// one of them does not exist.
Operation *op = input.getDefiningOp();
TensorType in_type = input.getType().dyn_cast<TensorType>();
if (in_type.getElementType() != outType.getElementType()) {
TensorType promoted_type =
in_type.cloneWith(in_type.getShape(), outType.getElementType());
input =
rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
}
ArrayRef<int64_t> inShape = in_type.getShape();
@@ -301,7 +309,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
bcastDims.push_back(outPos);
do_bcast = true;
} else {
op->emitError("The size of tensor a (") << inDim << ")"
op->emitError("The size of tensor a (")
<< inDim << ")"
<< "must match the size of tensor b (" << outDim << ")"
<< "at non-singleton dimension " << inPos;
}
@@ -311,10 +320,11 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
return input;
}
DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(bcastDims.size())}, rewriter.getI64Type()),
RankedTensorType::get({static_cast<long int>(bcastDims.size())},
rewriter.getI64Type()),
bcastDims);
auto bcast_op =
rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType, input, bcast_attr);
auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType,
input, bcast_attr);
return bcast_op.getResult();
}
@@ -418,5 +428,15 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
.getResult();
}
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType) {
auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
auto constTensor = rewriter.create<mhlo::ConstantOp>(loc, constAttr);
return rewriter
.create<mhlo::DynamicBroadcastInDimOp>(loc, outType, constTensor, shape,
rewriter.getI64TensorAttr({}))
.getResult();
}
} // namespace mhlo
} // namespace mlir
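The getConstantOfShape helper added above is what lets AtenBatchNormOp synthesize default weight/bias/running stats without knowing the channel count statically: it creates a scalar mhlo.constant and broadcasts it to a runtime shape. Schematically (illustrative types, empty broadcast_dimensions as in the helper):

    func.func @ones_of_shape(%shape: tensor<1xindex>) -> tensor<?xf32> {
      %c = mhlo.constant dense<1.000000e+00> : tensor<f32>
      %r = "mhlo.dynamic_broadcast_in_dim"(%c, %shape)
          {broadcast_dimensions = dense<> : tensor<0xi64>}
          : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
      return %r : tensor<?xf32>
    }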


@@ -59,6 +59,8 @@ LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
llvm::ArrayRef<int64_t> dshape,
bool checkForUnity);
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
TensorType outType);
@@ -78,7 +80,11 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value tensor,
ArrayRef<int64_t> inputUnsqzDims);
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape,
TensorType outType);
} // namespace mhlo
} // namespace mlir
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H


@@ -12,6 +12,7 @@
#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
@@ -32,6 +33,7 @@ namespace {
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::ChloDialect>();
registry.insert<mhlo::MhloDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithmeticDialect>();
@@ -40,7 +42,7 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<mhlo::MhloDialect, tensor::TensorDialect,
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect, tensor::TensorDialect,
arith::ArithmeticDialect, Torch::TorchDialect>();
TypeConverter typeConverter;
@@ -68,4 +70,4 @@ public:
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::createConvertTorchToMhloPass() {
return std::make_unique<ConvertTorchToMhlo>();
}


@@ -1,3 +1,19 @@
set(LinkedLibs MLIRIR
MLIRPass
MLIRFuncTransforms
TorchMLIRTorchConversionDialect
TorchMLIRTorchDialect
TorchMLIRTorchPasses
TorchMLIRTorchToLinalg
TorchMLIRTorchToTMTensor
TorchMLIRTorchToStd
TorchMLIRTorchToSCF
MLIRMemRefTransforms)
if(TORCH_MLIR_ENABLE_MHLO)
list(APPEND LinkedLibs ChloPasses)
endif()
add_mlir_library(TorchMLIRTorchConversionPasses
BackendTypeConversion.cpp
BackendTypeConversionPasses.cpp
@@ -17,15 +33,5 @@ add_mlir_library(TorchMLIRTorchConversionPasses
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRFuncTransforms
TorchMLIRTorchConversionDialect
TorchMLIRTorchDialect
TorchMLIRTorchPasses
TorchMLIRTorchToLinalg
TorchMLIRTorchToTMTensor
TorchMLIRTorchToStd
TorchMLIRTorchToSCF
MLIRMemRefTransforms
)
${LinkedLibs}
)


@@ -21,6 +21,7 @@
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#ifdef TORCH_MLIR_ENABLE_MHLO
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@@ -145,10 +146,20 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
}
// Convert CHLO ops to MHLO ops
pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
if (options.optimize) {
// Clean up any non-canonical code introduced above..
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
// The resolution of `dim` ops tends to create identical ops. CSE them.
pm.addNestedPass<func::FuncOp>(createCSEPass());
}
// Finish the type conversion from `torch` types to the types of the
// MHLO backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());
}
#endif
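The chlo-legalize-to-hlo pass inserted above is what removes the CHLO broadcast ops emitted by the new patterns: for dynamically shaped operands it computes the result shape with shape-dialect ops and broadcasts both sides through mhlo.dynamic_broadcast_in_dim before applying the plain MHLO elementwise op. A simplified sketch under that assumption (the real expansion also wraps this in shape.cstr_broadcastable/shape.assuming guards, omitted here; names and types are illustrative):

    func.func @expand_broadcast_add(%lhs: tensor<?x?xf32>, %rhs: tensor<?xf32>) -> tensor<?x?xf32> {
      %sl = shape.shape_of %lhs : tensor<?x?xf32> -> tensor<2xindex>
      %sr = shape.shape_of %rhs : tensor<?xf32> -> tensor<1xindex>
      %s = shape.broadcast %sl, %sr : tensor<2xindex>, tensor<1xindex> -> tensor<2xindex>
      %bl = "mhlo.dynamic_broadcast_in_dim"(%lhs, %s)
          {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
          : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
      %br = "mhlo.dynamic_broadcast_in_dim"(%rhs, %s)
          {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
          : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
      %sum = mhlo.add %bl, %br : tensor<?x?xf32>
      return %sum : tensor<?x?xf32>
    }

This shape expansion is also why the pipeline re-runs the canonicalizer and CSE immediately after the pass: the generated shape computations tend to contain duplicate shape_of/dim ops.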

File diff suppressed because it is too large.