mirror of https://github.com/llvm/torch-mlir
[MHLO] Support for dynamic shape in basic op conversion by introducing CHLO dialect (#1123)
* [MHLO] Support for dynamic shape in basic op conversion by introducing CHLO dialect
  Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
  Co-authored-by: Jiawei Wu <xremold@gmail.com>
  Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
  Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
  Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
* [MHLO] Support I32 as shape tensor dtype
* [NFC] Add a 'TODO' annotation
parent 3772e0bd91
commit 76c976682c
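The core of the change: elementwise Aten binary ops are no longer lowered by pre-broadcasting both operands to a static output shape with mhlo ops; the operands are only promoted to the output element type and handed to a chlo broadcasting op (e.g. chlo::BroadcastAddOp), which tolerates dynamic shapes and is expanded later by chlo-legalize-to-hlo. The sketch below condenses the pattern bodies from this diff into one template for illustration; the ConvertAtenBinaryOp name is invented for the example, scalar `other` operands and the add/sub `alpha` multiplier are omitted, and everything else is taken from the patch itself.

```cpp
// Condensed sketch of the new lowering scheme (illustrative only).
template <typename AtenOpT, typename ChloOpT>
class ConvertAtenBinaryOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;

  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.self();
    Value rhs = adaptor.other();
    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
                       ->convertType(op.getType())
                       .template cast<TensorType>();
    // Promote element types only; shapes (possibly dynamic) are reconciled
    // by the CHLO broadcasting op itself.
    lhs = mhlo::promoteType(rewriter, lhs, outType);
    rhs = mhlo::promoteType(rewriter, rhs, outType);
    // A null broadcast_dimensions attribute requests numpy-style broadcasting.
    DenseIntElementsAttr bcastDimensions;
    rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
                                         bcastDimensions);
    return success();
  }
};
```

The registration macros further down then map, for example, AtenAddTensorOp to chlo::BroadcastAddOp and AtenDivTensorOp to chlo::BroadcastDivOp.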
@@ -12,7 +12,10 @@
#include "../PassDetail.h"
#include "./MhloLegalizeUtils.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -51,7 +54,7 @@ public:
    auto selfTy = self.getType().cast<TensorType>();

    if (!selfTy)
      return op.emitError("Only Tensor types supported in MHLO");
      return op.emitError("only Tensor types supported in MHLO");

    if (selfTy.getElementType().isa<mlir::FloatType>()) {
      rewriter.replaceOpWithNewOp<MhloOpT>(
@@ -62,7 +65,7 @@ public:
      return success();
    } else {
      return op.emitError(
          "Only floating-point datatype legalization supported");
          "only floating-point datatype legalization supported");
    }
  }
};
@ -85,29 +88,29 @@ public:
|
|||
.template dyn_cast<TensorType>();
|
||||
|
||||
if (!outType)
|
||||
return op.emitError("Only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat())
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
"only floating-point or integer datatype legalization supported");
|
||||
|
||||
// FIXME: Handle layout, device and pin_memory. Assume dtype has been
|
||||
// processed to set output type correctly?
|
||||
if (!op.layout().getType().template isa<Torch::NoneType>())
|
||||
return op.emitError("Only default layout is supported");
|
||||
return op.emitError("only default layout is supported");
|
||||
|
||||
bool pinMemory;
|
||||
if (!op.pin_memory().getType().template isa<Torch::NoneType>() &&
|
||||
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
|
||||
pinMemory)) {
|
||||
return op.emitError(
|
||||
"Unsupported pin_memory, should be either None or false");
|
||||
"unsupported pin_memory, should be either None or false");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> shape;
|
||||
if (!matchPattern(op.size(), m_TorchConstantIntList(shape))) {
|
||||
return op.emitError("Shape must be a list of Scalar constants");
|
||||
return op.emitError("shape must be a list of Scalar constants");
|
||||
}
|
||||
|
||||
int64_t size = 1;
|
||||
|
@ -128,7 +131,7 @@ public:
|
|||
// These binary op legalizations are specific to add/sub which have an
|
||||
// alpha multiplier.
|
||||
namespace {
|
||||
template <typename AtenOpT, typename MhloOpT>
|
||||
template <typename AtenOpT, typename ChloOpT>
|
||||
class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
|
@ -137,12 +140,12 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.self();
|
||||
TensorType lhsType = lhs.getType().dyn_cast<TensorType>();
|
||||
RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
if (!lhsType)
|
||||
return op.emitError("Only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
|
||||
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
|
@ -151,53 +154,44 @@ public:
|
|||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat()) {
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
"only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
|
||||
Value rhsAsTensor;
|
||||
if (!rhsType) {
|
||||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(),
|
||||
rhsAsTensor, outElemTy,
|
||||
outType.getShape())))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
|
||||
outElemTy, {})))
|
||||
return op.emitError("currently only scalar constants are supported for "
|
||||
"conversion in MHLO operation");
|
||||
}
|
||||
Value lhsTensor = lhs;
|
||||
Value rhsTensor = rhsType ? rhs : rhsAsTensor;
|
||||
|
||||
// Handle broadcasting. Since we have the output type already, here we
|
||||
// just broodcast operands' shape to output shape.
|
||||
lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType);
|
||||
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType);
|
||||
lhs = mhlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = mhlo::promoteType(rewriter, rhs, outType);
|
||||
|
||||
// Handle alpha.
|
||||
Value multTensor;
|
||||
if (skipMultiplyAlpha(op.alpha())) {
|
||||
multTensor = rhsTensor;
|
||||
} else {
|
||||
Value alphaTensor;
|
||||
if (!skipMultiplyAlpha(op.alpha())) {
|
||||
Value alpha;
|
||||
if (failed(mhlo::torchAlphaToMhloTensor(rewriter, op.getOperation(),
|
||||
op.alpha(), alphaTensor,
|
||||
outElemTy, outType.getShape(),
|
||||
op.alpha(), alpha, outElemTy, {},
|
||||
/*checkForUnity=*/false))) {
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
return op.emitError("currently only scalar constants are supported for "
|
||||
"alpha in conversion to MHLO operation");
|
||||
}
|
||||
|
||||
multTensor = rewriter.create<mhlo::MulOp>(op.getLoc(), outType, rhsTensor,
|
||||
alphaTensor);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
|
||||
bcastDimensions);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<MhloOpT>(op, outType, lhsTensor, multTensor);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
||||
bcastDimensions);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Binary op legalizations for Mul variants.
|
||||
// Binary op legalizations for Mul/Div variants.
|
||||
namespace {
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
|
||||
template <typename AtenOpT, typename ChloOpT>
|
||||
class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
|
@ -210,7 +204,7 @@ public:
|
|||
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (!lhsType)
|
||||
return op.emitError("Only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
|
@ -219,86 +213,23 @@ public:
|
|||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat()) {
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
"only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
|
||||
Value lhsTensor = lhs;
|
||||
Value rhsTensor;
|
||||
if (std::is_same<AtenOpT, AtenSquareOp>()) {
|
||||
rhsTensor = lhs;
|
||||
} else {
|
||||
if (!rhsType) {
|
||||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(),
|
||||
rhsTensor, outElemTy,
|
||||
outType.getShape())))
|
||||
return op.emitError(
|
||||
"Currently only scalar constants are supported for "
|
||||
"conversion in MHLO operation");
|
||||
} else {
|
||||
rhsTensor = rhs;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle broadcasting. Since we have the output type already, here we
|
||||
// just broodcast operands' shape to output shape.
|
||||
lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType);
|
||||
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, lhsTensor, rhsTensor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Binary op legalizations for Div variants.
|
||||
namespace {
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (!lhsTy)
|
||||
return op.emitError("Only Tensor types supported.");
|
||||
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat()) {
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
|
||||
Value lhsTensor = lhs;
|
||||
Value rhsTensor;
|
||||
if (!rhsTy) {
|
||||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(),
|
||||
rhsTensor, outElemTy,
|
||||
outType.getShape())))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
if (!rhsType) {
|
||||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
|
||||
outElemTy, {})))
|
||||
return op.emitError("currently only scalar constants are supported for "
|
||||
"conversion in MHLO operation");
|
||||
} else {
|
||||
rhsTensor = rhs;
|
||||
}
|
||||
|
||||
// Handle broadcasting. Since we have the output type already, here we
|
||||
// just broodcast operands' shape to output shape.
|
||||
lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType);
|
||||
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outType, lhsTensor, rhsTensor);
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
lhs = mhlo::promoteType(rewriter, lhs, outType);
|
||||
rhs = mhlo::promoteType(rewriter, rhs, outType);
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
||||
bcastDimensions);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Binary op legalizations for comparator ops.
|
||||
|
@ -318,7 +249,7 @@ public:
|
|||
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
if (!lhsTy)
|
||||
return op.emitError("Only Tensor types supported in MHLO");
|
||||
return op.emitError("only Tensor types supported in MHLO");
|
||||
|
||||
RankedTensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
|
@ -327,21 +258,19 @@ public:
|
|||
Type lhsElemTy = lhsTy.getElementType();
|
||||
if (!lhsElemTy.isIntOrFloat()) {
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
"only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
|
||||
Value rhsAsTensor;
|
||||
if (!rhsTy) {
|
||||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(),
|
||||
rhsAsTensor, lhsElemTy, {}))) {
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs,
|
||||
lhsElemTy, {}))) {
|
||||
return op.emitError("currently only scalar constants are supported for "
|
||||
"conversion in MHLO operation");
|
||||
}
|
||||
}
|
||||
|
||||
Value lhsTensor = lhs;
|
||||
Value rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, lhsTy);
|
||||
// TODO: what is the PyTorch default type promotion?
|
||||
rhs = mhlo::promoteType(rewriter, rhs, lhsTy);
|
||||
|
||||
mhlo::ComparisonTypeAttr compareTypeAttr;
|
||||
mhlo::ComparisonDirectionAttr compareDirectionAttr;
|
||||
|
@ -371,9 +300,9 @@ public:
|
|||
compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
|
||||
op->getContext(), mhlo::ComparisonDirection::NE);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::CompareOp>(
|
||||
op, outType, lhsTensor, rhsTensor, compareDirectionAttr,
|
||||
DenseIntElementsAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastCompareOp>(
|
||||
op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr,
|
||||
compareTypeAttr);
|
||||
return success();
|
||||
}
|
||||
|
@@ -438,12 +367,63 @@ public:
  matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value self = adaptor.self();
    auto selfTy = self.getType().cast<RankedTensorType>();
    auto outType = getTypeConverter()
                       ->convertType(op->getResult(0).getType())
                       .cast<RankedTensorType>();

    Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
    rewriter.replaceOp(op, bcastOp);
#ifdef TORCH_MLIR_ENABLE_MHLO_STATIC_SHAPE
    if (selfTy.hasStaticShape()) {
      Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
      rewriter.replaceOp(op, bcastOp);
      return success();
    }
#endif

    SmallVector<Value> shape;
    if (!(getListConstructElements(adaptor.size(), shape))) {
      return op->emitError("desired shape must be a list of scalar");
    }
    SmallVector<Value> bcastShapeVec;
    int64_t totalRank = shape.size();
    int64_t selfRank = selfTy.getRank();
    int64_t leadingRank = totalRank - selfRank;

    for (int64_t i = 0; i < totalRank; ++i) {
      Value dValue = shape[i];
      Value newD;
      int64_t dInt;
      if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) {
        return op->emitError("element of desired shape must be a scalar");
      }
      if (i >= leadingRank && dInt == -1) {
        newD = rewriter.create<mlir::tensor::DimOp>(op->getLoc(), self,
                                                    i - leadingRank);
      } else {
        dValue = rewriter.create<torch::TorchConversion::ToI64Op>(op->getLoc(),
                                                                  dValue);
        newD = rewriter.create<mlir::arith::IndexCastOp>(
            op->getLoc(), rewriter.getIndexType(), dValue);
      }
      bcastShapeVec.push_back(newD);
    }

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
    for (auto &dsize : bcastShapeVec) {
      auto dsizeI64 = rewriter.create<mlir::arith::IndexCastOp>(
          op->getLoc(), rewriter.getI64Type(), dsize);
      dsize = rewriter.create<arith::TruncIOp>(op->getLoc(),
                                               rewriter.getI32Type(), dsizeI64);
    }
#endif

    Value bcastShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
        op->getLoc(), ValueRange{bcastShapeVec});
    auto dimensionNumbers =
        llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
    rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
        op, outType, self, bcastShapeTensor,
        rewriter.getI64TensorAttr(dimensionNumbers));
    return success();
  }
};
@ -464,19 +444,19 @@ public:
|
|||
->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
if (!inType)
|
||||
return op.emitError("Only ranked tensor types with static shapes are "
|
||||
return op.emitError("only ranked tensor types with static shapes are "
|
||||
"currently supported");
|
||||
|
||||
SmallVector<int64_t> permValues;
|
||||
if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only constant dimensions are currently supported");
|
||||
op, "only constant dimensions are currently supported");
|
||||
|
||||
int64_t inRank = inType.getRank();
|
||||
for (auto &d : permValues) {
|
||||
d = toPositiveDim(d, inRank);
|
||||
if (!isValidDim(d, inRank))
|
||||
return op.emitError("Not all dims are valid");
|
||||
return op.emitError("not all dims are valid");
|
||||
}
|
||||
|
||||
DenseIntElementsAttr permutation = DenseIntElementsAttr::get(
|
||||
|
@ -517,7 +497,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
|||
return success();
|
||||
} else {
|
||||
return op.emitError(
|
||||
"Only floating-point datatype legalization currently supported");
|
||||
"only floating-point datatype legalization currently supported");
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
@ -535,6 +515,7 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
|||
// Tensors with integer types need to be converted to signless integer
|
||||
// element type. All tensors with element types other than integer can reuse
|
||||
// existing elements attribute.
|
||||
// TODO: what about unsigned integer?
|
||||
if (auto elements = op.valueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
||||
Type builtinTensorElemTy = resultType.getElementType();
|
||||
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
|
||||
|
@ -566,13 +547,11 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
|||
auto outTy =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
if (!inputTy.getElementType().isa<mlir::FloatType>()) {
|
||||
return op.emitError("Only floating-point datatype legalization supported "
|
||||
return op.emitError("only floating-point datatype legalization supported "
|
||||
"for AtenReciprocalOp");
|
||||
}
|
||||
Value oneTensor =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {static_cast<float>(1.0)}, {})
|
||||
.getValue();
|
||||
oneTensor = mhlo::promoteAndBroadcast(rewriter, oneTensor, inputTy);
|
||||
|
||||
Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input);
|
||||
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, oneTensor, input);
|
||||
return success();
|
||||
}
|
||||
|
@ -593,7 +572,7 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
|||
if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.a(), mhloTensor,
|
||||
outputElemType, outputShape,
|
||||
false))) {
|
||||
return op->emitError("Failed lowering PrimNumToTensorScalarOp to MHLO");
|
||||
return op->emitError("failed lowering PrimNumToTensorScalarOp to MHLO");
|
||||
}
|
||||
rewriter.replaceOp(op, mhloTensor);
|
||||
return success();
|
||||
|
@ -611,7 +590,7 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
|||
// Not a tensor type.
|
||||
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
||||
if (!selfType)
|
||||
return op.emitError("Only tensor types are currently supported");
|
||||
return op.emitError("only tensor types are currently supported");
|
||||
|
||||
// FIXME: memory_format is not handled.
|
||||
|
||||
|
@ -633,27 +612,17 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
|||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
|
||||
int64_t lhsSize = 1;
|
||||
for (auto &en : llvm::enumerate(lhsTy.getShape())) {
|
||||
lhsSize *= en.value();
|
||||
if (!lhsElemTy.isa<mlir::FloatType>()) {
|
||||
return op->emitError("only float tensor in relu op is supported");
|
||||
}
|
||||
auto constTy = RankedTensorType::get(lhsTy.getShape(), lhsElemTy);
|
||||
DenseElementsAttr constAttr;
|
||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
||||
std::vector<APFloat> constVec(
|
||||
lhsSize,
|
||||
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
/*negative=*/false));
|
||||
constAttr = DenseElementsAttr::get(constTy, constVec);
|
||||
} else if (lhsElemTy.isa<mlir::IntegerType>()) {
|
||||
std::vector<APInt> constVec(
|
||||
lhsSize, APInt::getZero(lhsElemTy.getIntOrFloatBitWidth()));
|
||||
constAttr = DenseElementsAttr::get(constTy, constVec);
|
||||
}
|
||||
Value rhs =
|
||||
rewriter.create<mhlo::ConstantOp>(op.getLoc(), constTy, constAttr);
|
||||
|
||||
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, rhs);
|
||||
Value zeroTensor;
|
||||
zeroTensor = chlo::getConstantLike(
|
||||
rewriter, op->getLoc(),
|
||||
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
|
||||
false),
|
||||
lhs);
|
||||
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -666,72 +635,12 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
|||
AtenErfOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value input = adaptor.self();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputType = input.getType().cast<TensorType>();
|
||||
if (!inputType.getElementType().isa<mlir::FloatType>()) {
|
||||
return rewriter.notifyMatchFailure(op, "Only support float data type");
|
||||
return rewriter.notifyMatchFailure(op, "only float tensor is supported");
|
||||
}
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
|
||||
// Using:
|
||||
// https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with
|
||||
// maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 =
|
||||
// 0.000972, a4 = 0.078108.
|
||||
// Erf = 1 - 1 / (1 + a1X + a2X^2 + a3X^3 + a4X^4)^4
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto zeroConst =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).getValue();
|
||||
auto zero = mhlo::promoteAndBroadcast(rewriter, zeroConst, outType);
|
||||
auto oneConst =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {1.0}, {}).getValue();
|
||||
auto one = mhlo::promoteAndBroadcast(rewriter, oneConst, outType);
|
||||
auto a1Const =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {0.278393}, {}).getValue();
|
||||
auto a1 = mhlo::promoteAndBroadcast(rewriter, a1Const, outType);
|
||||
auto a2Const =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {0.230389}, {}).getValue();
|
||||
auto a2 = mhlo::promoteAndBroadcast(rewriter, a2Const, outType);
|
||||
auto a3Const =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {0.000972}, {}).getValue();
|
||||
auto a3 = mhlo::promoteAndBroadcast(rewriter, a3Const, outType);
|
||||
auto a4Const =
|
||||
mhlo::getConstTensor<float>(rewriter, op, {0.078108}, {}).getValue();
|
||||
auto a4 = mhlo::promoteAndBroadcast(rewriter, a4Const, outType);
|
||||
|
||||
auto absX = rewriter.create<mhlo::AbsOp>(loc, outType, input);
|
||||
auto a1X = rewriter.create<mhlo::MulOp>(loc, outType, a1, absX);
|
||||
auto sum = rewriter.create<mhlo::AddOp>(loc, outType, a1X, one);
|
||||
|
||||
auto x2 = rewriter.create<mhlo::MulOp>(loc, outType, absX, absX);
|
||||
auto a2X = rewriter.create<mhlo::MulOp>(loc, outType, a2, x2);
|
||||
sum = rewriter.create<mhlo::AddOp>(loc, outType, sum, a2X);
|
||||
|
||||
auto x3 = rewriter.create<mhlo::MulOp>(loc, outType, x2, absX);
|
||||
auto a3X = rewriter.create<mhlo::MulOp>(loc, outType, a3, x3);
|
||||
sum = rewriter.create<mhlo::AddOp>(loc, outType, sum, a3X);
|
||||
|
||||
auto x4 = rewriter.create<mhlo::MulOp>(loc, outType, x3, absX);
|
||||
auto a4X = rewriter.create<mhlo::MulOp>(loc, outType, a4, x4);
|
||||
sum = rewriter.create<mhlo::AddOp>(loc, outType, sum, a4X);
|
||||
|
||||
auto rcprl = rewriter.create<mhlo::DivOp>(loc, outType, one, sum);
|
||||
auto rcprl2 = rewriter.create<mhlo::MulOp>(loc, outType, rcprl, rcprl);
|
||||
auto rcprl4 = rewriter.create<mhlo::MulOp>(loc, outType, rcprl2, rcprl2);
|
||||
auto erf = rewriter.create<mhlo::SubtractOp>(loc, outType, one, rcprl4);
|
||||
|
||||
// Deal with negative x.
|
||||
mhlo::ComparisonDirectionAttr compareDirectionAttr =
|
||||
mhlo::ComparisonDirectionAttr::get(op->getContext(),
|
||||
mhlo::ComparisonDirection::GE);
|
||||
mhlo::ComparisonTypeAttr compareTypeAttr = mhlo::ComparisonTypeAttr::get(
|
||||
op->getContext(), mhlo::ComparisonType::FLOAT);
|
||||
auto geZero = rewriter.create<mhlo::CompareOp>(
|
||||
loc, RankedTensorType::get(outType.getShape(), rewriter.getI1Type()),
|
||||
input, zero, compareDirectionAttr, compareTypeAttr);
|
||||
auto negaErf = rewriter.create<mhlo::NegOp>(loc, erf);
|
||||
rewriter.replaceOpWithNewOp<mhlo::SelectOp>(op, outType, geZero, erf,
|
||||
negaErf);
|
||||
rewriter.replaceOpWithNewOp<chlo::ErfOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), input);
|
||||
return success();
|
||||
}
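For reference, the hand-rolled approximation that the deleted erf lowering evaluated (quoted in the comment above) is, for \(x \ge 0\),
\[
  \operatorname{erf}(x) \;\approx\; 1 - \frac{1}{\bigl(1 + a_1 x + a_2 x^2 + a_3 x^3 + a_4 x^4\bigr)^{4}},
  \qquad a_1 = 0.278393,\ a_2 = 0.230389,\ a_3 = 0.000972,\ a_4 = 0.078108,
\]
with maximum absolute error about \(5\times 10^{-4}\); negative inputs were folded in through the odd symmetry \(\operatorname{erf}(-x) = -\operatorname{erf}(x)\) (the CompareOp/SelectOp pair being removed). The replacement emits a single chlo::ErfOp and leaves the choice of approximation to the CHLO-to-MHLO legalization.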
|
||||
|
||||
|
@ -754,59 +663,71 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
Value momentum = adaptor.momentum();
|
||||
(void)momentum;
|
||||
|
||||
// init weight, bias, runningVar, runningMean if they are none
|
||||
auto initNoneValue = [&](Value &input, bool zero) {
|
||||
SmallVector<APFloat> constVec(inputTy.getShape()[1],
|
||||
APFloat::getZero(inputTy.getElementType()
|
||||
.cast<mlir::FloatType>()
|
||||
.getFloatSemantics()));
|
||||
if (!zero) {
|
||||
for (auto &item : constVec) {
|
||||
item = APFloat(inputTy.getElementType()
|
||||
.cast<mlir::FloatType>()
|
||||
.getFloatSemantics(),
|
||||
1);
|
||||
}
|
||||
}
|
||||
auto constType = RankedTensorType::get({inputTy.getShape()[1]},
|
||||
inputTy.getElementType());
|
||||
auto constAttr = DenseElementsAttr::get(constType, constVec);
|
||||
input =
|
||||
rewriter.create<mhlo::ConstantOp>(op.getLoc(), constType, constAttr);
|
||||
};
|
||||
if (inputTy.getRank() <= 2) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"input should have rank larger than 2");
|
||||
}
|
||||
if (!inputTy.getElementType().template isa<mlir::FloatType>()) {
|
||||
return op.emitError("only input tensor of float type is supported");
|
||||
}
|
||||
auto inputElemTy = inputTy.getElementType().cast<mlir::FloatType>();
|
||||
|
||||
Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1);
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
|
||||
auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>(
|
||||
op->getLoc(), rewriter.getI64Type(), channelDim);
|
||||
channelDim = rewriter.create<arith::TruncIOp>(
|
||||
op->getLoc(), rewriter.getI32Type(), channelDimI64);
|
||||
#endif
|
||||
|
||||
Value channelShape = rewriter.create<tensor::FromElementsOp>(
|
||||
op->getLoc(), ValueRange{channelDim});
|
||||
if (failed(checkNotNone(rewriter, op, weight))) {
|
||||
initNoneValue(weight, false);
|
||||
weight = mhlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
inputTy.getElementType()));
|
||||
}
|
||||
if (failed(checkNotNone(rewriter, op, bias))) {
|
||||
initNoneValue(bias, true);
|
||||
bias = mhlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
inputTy.getElementType()));
|
||||
}
|
||||
if (failed(checkNotNone(rewriter, op, runningVar))) {
|
||||
initNoneValue(runningVar, false);
|
||||
runningVar = mhlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
inputTy.getElementType()));
|
||||
}
|
||||
if (failed(checkNotNone(rewriter, op, runningMean))) {
|
||||
initNoneValue(runningMean, true);
|
||||
runningMean = mhlo::getConstantOfShape(
|
||||
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
||||
channelShape,
|
||||
RankedTensorType::get({inputTy.getShape()[1]},
|
||||
inputTy.getElementType()));
|
||||
}
|
||||
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto biasTy = bias.getType().cast<RankedTensorType>();
|
||||
auto runningMeanTy = runningMean.getType().cast<RankedTensorType>();
|
||||
auto runningVarTy = runningVar.getType().cast<RankedTensorType>();
|
||||
if (inputTy.getRank() <= 2) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"input should have rank larger than 2");
|
||||
}
|
||||
|
||||
if (weightTy.getRank() != 1 || biasTy.getRank() != 1 ||
|
||||
runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expect weight, bias, running_mean and running_var to be rank 1");
|
||||
}
|
||||
if (!inputTy.getElementType().template isa<mlir::FloatType>() ||
|
||||
!weightTy.getElementType().template isa<mlir::FloatType>() ||
|
||||
if (!weightTy.getElementType().template isa<mlir::FloatType>() ||
|
||||
!biasTy.getElementType().template isa<mlir::FloatType>() ||
|
||||
!runningMeanTy.getElementType().template isa<mlir::FloatType>() ||
|
||||
!runningVarTy.getElementType().template isa<mlir::FloatType>()) {
|
||||
return op.emitError(
|
||||
"Only float element type is supported in MHLO BatchNormOp");
|
||||
return op.emitError("only float weight/bias/runningMean/runningVar tensor "
|
||||
"of float type is supported");
|
||||
}
|
||||
|
||||
double eps = 0.0;
|
||||
|
@ -858,6 +779,10 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
Value weight = adaptor.weight();
|
||||
Value bias = adaptor.bias();
|
||||
|
||||
if (!inputTy.hasStaticShape()) {
|
||||
return op->emitError("dynamic shaped input is not supported");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> normalizedShape;
|
||||
if (!matchPattern(op.normalized_shape(),
|
||||
m_TorchConstantIntList(normalizedShape))) {
|
||||
|
@ -866,11 +791,12 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
}
|
||||
double eps = 0;
|
||||
if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) {
|
||||
return rewriter.notifyMatchFailure(op, "non const float eps unsupported");
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"non const float eps is unsupported");
|
||||
}
|
||||
if (failed(checkNotNone(rewriter, op, weight)) ||
|
||||
failed(checkNotNone(rewriter, op, bias))) {
|
||||
return op->emitError("Unsupported None for weight or bias");
|
||||
return op->emitError("none weight or bias is unsupported");
|
||||
}
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
auto biasTy = bias.getType().cast<RankedTensorType>();
|
||||
|
@ -878,13 +804,13 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
if (!inputTy.getElementType().isa<mlir::FloatType>() ||
|
||||
!biasTy.getElementType().isa<mlir::FloatType>() ||
|
||||
!weightTy.getElementType().isa<mlir::FloatType>()) {
|
||||
return op->emitError("For now, only float data type are supported");
|
||||
return op->emitError("currently only float data type are supported");
|
||||
}
|
||||
int64_t normalizedShapeRank = normalizedShape.size();
|
||||
if (weightTy.getRank() != normalizedShapeRank ||
|
||||
biasTy.getRank() != normalizedShapeRank ||
|
||||
inputRank < normalizedShapeRank || normalizedShapeRank < 1) {
|
||||
return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or"
|
||||
return rewriter.notifyMatchFailure(op, "input or weight or bias shape or"
|
||||
"normalized shape not compatible");
|
||||
}
|
||||
for (int64_t i = 1; i <= normalizedShapeRank; i++) {
|
||||
|
@ -897,7 +823,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
}
|
||||
}
|
||||
|
||||
// flatten dims to fit batch_norm operation.
|
||||
// Flatten dims to fit batch_norm operation.
|
||||
int64_t numFeatureDimSize = 1;
|
||||
int64_t numEmbeddingDimSize = 1;
|
||||
for (int64_t i = 0; i < inputRank - normalizedShapeRank; i++) {
|
||||
|
@ -915,14 +841,14 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
auto mhloBathNormOutMeanOrVarTy =
|
||||
RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType());
|
||||
|
||||
// reshape input
|
||||
// Reshape input
|
||||
auto mhloInput = rewriter.create<mhlo::DynamicReshapeOp>(
|
||||
op->getLoc(), mhloBatchNormOutTy, input,
|
||||
mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape),
|
||||
{static_cast<int64_t>(inputFlattenShape.size())})
|
||||
.getValue());
|
||||
|
||||
// generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
|
||||
// Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp.
|
||||
SmallVector<APFloat> zeroConstVec(
|
||||
numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
|
||||
.cast<mlir::FloatType>()
|
||||
|
@ -946,7 +872,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|||
mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset,
|
||||
rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
|
||||
|
||||
// reshape back
|
||||
// Reshape back
|
||||
auto outputTy =
|
||||
getTypeConverter()->convertType(op.getType(0)).cast<RankedTensorType>();
|
||||
auto outputMeanOrVarTy =
|
||||
|
@ -1016,32 +942,28 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
||||
#undef INSERT_CONSTANT_FILL_PATTERN
|
||||
|
||||
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, MhloOp) \
|
||||
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenAddSubOp<AtenOp, MhloOp>>(typeConverter, context);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, mhlo::AddOp);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, mhlo::AddOp);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, mhlo::SubtractOp);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, mhlo::SubtractOp);
|
||||
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp);
|
||||
#undef INSERT_BINARY_ADDSUB_PATTERN
|
||||
|
||||
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
|
||||
#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMulOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp);
|
||||
INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp);
|
||||
#undef INSERT_BINARY_MUL_PATTERN
|
||||
|
||||
#define INSERT_BINARY_DIV_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
|
||||
#undef INSERT_BINARY_DIV_PATTERN
|
||||
patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context);
|
||||
INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
|
||||
INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
|
||||
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
|
||||
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
|
||||
#undef INSERT_BINARY_MULDIV_PATTERN
|
||||
|
||||
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context);
|
||||
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
|
||||
|
@ -1067,4 +989,4 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
||||
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
}
|
||||
}
|
|
@ -10,6 +10,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
|
|||
|
||||
DEPENDS
|
||||
MhloDialect
|
||||
ChloDialect
|
||||
TorchMLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
|
@ -19,6 +20,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
|
|||
MLIRIR
|
||||
MLIRPass
|
||||
MhloDialect
|
||||
ChloDialect
|
||||
TorchMLIRTorchDialect
|
||||
)
|
||||
|
||||
|
|
|
@ -117,9 +117,9 @@ llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
|
|||
}
|
||||
|
||||
template <>
|
||||
llvm::Optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
|
||||
Operation *op, ArrayRef<double> vec,
|
||||
ArrayRef<int64_t> shape) {
|
||||
llvm::Optional<Value>
|
||||
getConstTensor<double>(PatternRewriter &rewriter, Operation *op,
|
||||
ArrayRef<double> vec, ArrayRef<int64_t> shape) {
|
||||
uint64_t num_total_elements = 1;
|
||||
for (int64_t a : shape) {
|
||||
num_total_elements *= a;
|
||||
|
@ -149,7 +149,6 @@ template llvm::Optional<Value> getConstTensor<int64_t>(PatternRewriter &,
|
|||
ArrayRef<int64_t> vec,
|
||||
ArrayRef<int64_t> shape);
|
||||
|
||||
|
||||
template <typename T>
|
||||
static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
|
||||
const int64_t &intValue) {
|
||||
|
@ -166,20 +165,16 @@ static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
Value getSplatConstTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op,
|
||||
T val,
|
||||
Type dtype,
|
||||
llvm::ArrayRef<int64_t> dshape) {
|
||||
auto const_type = RankedTensorType::get(
|
||||
dshape, dtype);
|
||||
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
|
||||
auto const_type = RankedTensorType::get(dshape, dtype);
|
||||
auto const_attr = SplatElementsAttr::get(const_type, val);
|
||||
auto const_op =
|
||||
rewriter.create<mhlo::ConstantOp>(op->getLoc(), const_type, const_attr);
|
||||
return const_op.getResult();
|
||||
}
|
||||
|
||||
|
||||
// TODO: Support for variable scalar.
|
||||
LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value torchScalarValue,
|
||||
Value &mhloTensor, Type dtype,
|
||||
|
@ -198,9 +193,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
|
|||
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
if (doBroadcast) {
|
||||
mhloTensor = getSplatConstTensor<float>(rewriter, op,
|
||||
(isFloat ? doubleValue : intValue),
|
||||
dtype, dshape);
|
||||
mhloTensor = getSplatConstTensor<float>(
|
||||
rewriter, op, (isFloat ? doubleValue : intValue), dtype, dshape);
|
||||
} else {
|
||||
mhloTensor = mhlo::getConstTensor<float>(
|
||||
rewriter, op, (isFloat ? doubleValue : intValue), dshape)
|
||||
|
@ -219,7 +213,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
|
|||
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
|
||||
: static_cast<int32_t>(intValue);
|
||||
if (doBroadcast) {
|
||||
mhloTensor = getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
|
||||
mhloTensor =
|
||||
getSplatConstTensor<int32_t>(rewriter, op, d, dtype, dshape);
|
||||
} else {
|
||||
mhloTensor =
|
||||
mhlo::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
|
||||
|
@ -231,7 +226,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
|
||||
if (doBroadcast) {
|
||||
mhloTensor = getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
|
||||
mhloTensor =
|
||||
getSplatConstTensor<int64_t>(rewriter, op, d, dtype, dshape);
|
||||
} else {
|
||||
mhloTensor =
|
||||
mhlo::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
|
||||
|
@ -243,7 +239,6 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
|
|||
return success();
|
||||
}
|
||||
|
||||
|
||||
LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value alphaScalar,
|
||||
Value &alphaTensor, Type dtype,
|
||||
|
@ -268,20 +263,33 @@ LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
|
|||
return success();
|
||||
}
|
||||
|
||||
|
||||
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
|
||||
Value input, TensorType outType) {
|
||||
// Two tensors are “broadcastable” if the following rules hold:
|
||||
// - Each tensor has at least one dimension.
|
||||
// - When iterating over the dimension sizes, starting at the trailing dimension,
|
||||
// the dimension sizes must either be equal, one of them is 1, or one of them
|
||||
// does not exist.
|
||||
Operation* op = input.getDefiningOp();
|
||||
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
|
||||
Operation *op = input.getDefiningOp();
|
||||
TensorType in_type = input.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (in_type.getElementType() != outType.getElementType()) {
|
||||
TensorType promoted_type = in_type.cloneWith(in_type.getShape(), outType.getElementType());
|
||||
input = rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
|
||||
TensorType promotedType =
|
||||
in_type.cloneWith(in_type.getShape(), outType.getElementType());
|
||||
return rewriter.create<mhlo::ConvertOp>(op->getLoc(), promotedType, input);
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
||||
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
||||
TensorType outType) {
|
||||
// Two tensors are “broadcastable” if the following rules hold:
|
||||
// - Each tensor has at least one dimension.
|
||||
// - When iterating over the dimension sizes, starting at the trailing
|
||||
// dimension, the dimension sizes must either be equal, one of them is 1, or
|
||||
// one of them does not exist.
|
||||
Operation *op = input.getDefiningOp();
|
||||
TensorType in_type = input.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (in_type.getElementType() != outType.getElementType()) {
|
||||
TensorType promoted_type =
|
||||
in_type.cloneWith(in_type.getShape(), outType.getElementType());
|
||||
input =
|
||||
rewriter.create<mhlo::ConvertOp>(op->getLoc(), promoted_type, input);
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> inShape = in_type.getShape();
|
||||
|
@ -301,7 +309,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
|
|||
bcastDims.push_back(outPos);
|
||||
do_bcast = true;
|
||||
} else {
|
||||
op->emitError("The size of tensor a (") << inDim << ")"
|
||||
op->emitError("The size of tensor a (")
|
||||
<< inDim << ")"
|
||||
<< "must match the size of tensor b (" << outDim << ")"
|
||||
<< "at non-singleton dimension " << inPos;
|
||||
}
|
||||
|
@ -311,10 +320,11 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
|
|||
return input;
|
||||
}
|
||||
DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({static_cast<long int>(bcastDims.size())}, rewriter.getI64Type()),
|
||||
RankedTensorType::get({static_cast<long int>(bcastDims.size())},
|
||||
rewriter.getI64Type()),
|
||||
bcastDims);
|
||||
auto bcast_op =
|
||||
rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType, input, bcast_attr);
|
||||
auto bcast_op = rewriter.create<mhlo::BroadcastInDimOp>(op->getLoc(), outType,
|
||||
input, bcast_attr);
|
||||
return bcast_op.getResult();
|
||||
}
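A quick worked example of the broadcastability rule documented at the top of promoteAndBroadcast (added for illustration; the shapes below are made up, not taken from the patch):

```cpp
// Dimensions are compared right to left; a missing or size-1 input dimension
// stretches to match the target, and broadcast_dimensions records which
// output position each input dimension maps to.
//   input shape: (3, 1, 5), target (3, 4, 5) -> broadcast_dimensions = [0, 1, 2]
//   input shape:    (4, 5), target (3, 4, 5) -> broadcast_dimensions = [1, 2]
//   input shape:    (2, 5), target (3, 4, 5) -> error: 2 vs 4, neither is 1
```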
@@ -418,5 +428,15 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
      .getResult();
}

Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                         const APFloat &constant, Value shape,
                         TensorType outType) {
  auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
  auto constTensor = rewriter.create<mhlo::ConstantOp>(loc, constAttr);
  return rewriter
      .create<mhlo::DynamicBroadcastInDimOp>(loc, outType, constTensor, shape,
                                             rewriter.getI64TensorAttr({}))
      .getResult();
}
} // namespace mhlo
} // namespace mlir
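The new getConstantOfShape helper is what lets the batch-norm lowering materialize defaults for None weight/bias/running stats without a static channel count. A short usage sketch, mirroring the AtenBatchNormOp hunk earlier in this diff (the names input, inputTy and inputElemTy come from that hunk):

```cpp
// Build a 1-D shape tensor holding the runtime channel dimension, then splat
// a constant (here 1.0) to that shape for a missing weight operand.
Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1);
Value channelShape = rewriter.create<tensor::FromElementsOp>(
    op->getLoc(), ValueRange{channelDim});
weight = mhlo::getConstantOfShape(
    rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
    channelShape,
    RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType()));
```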
|
|
@ -59,6 +59,8 @@ LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
|
|||
llvm::ArrayRef<int64_t> dshape,
|
||||
bool checkForUnity);
|
||||
|
||||
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
|
||||
|
||||
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
||||
TensorType outType);
|
||||
|
||||
|
@ -78,7 +80,11 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
|||
Value tensor,
|
||||
ArrayRef<int64_t> inputUnsqzDims);
|
||||
|
||||
|
||||
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
||||
const APFloat &constant, Value shape,
|
||||
TensorType outType);
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
|
||||
#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
|
|
@ -12,6 +12,7 @@
|
|||
#include "../PassDetail.h"
|
||||
#include "./PopulatePatterns.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
|
@ -32,6 +33,7 @@ namespace {
|
|||
class ConvertTorchToMhlo : public ConvertTorchToMhloBase<ConvertTorchToMhlo> {
|
||||
public:
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<chlo::ChloDialect>();
|
||||
registry.insert<mhlo::MhloDialect>();
|
||||
registry.insert<tensor::TensorDialect>();
|
||||
registry.insert<arith::ArithmeticDialect>();
|
||||
|
@ -40,7 +42,7 @@ public:
|
|||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<mhlo::MhloDialect, tensor::TensorDialect,
|
||||
target.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect, tensor::TensorDialect,
|
||||
arith::ArithmeticDialect, Torch::TorchDialect>();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
|
@ -68,4 +70,4 @@ public:
|
|||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
mlir::torch::createConvertTorchToMhloPass() {
|
||||
return std::make_unique<ConvertTorchToMhlo>();
|
||||
}
|
||||
}
|
|
@ -1,3 +1,19 @@
|
|||
set(LinkedLibs MLIRIR
|
||||
MLIRPass
|
||||
MLIRFuncTransforms
|
||||
TorchMLIRTorchConversionDialect
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRTorchPasses
|
||||
TorchMLIRTorchToLinalg
|
||||
TorchMLIRTorchToTMTensor
|
||||
TorchMLIRTorchToStd
|
||||
TorchMLIRTorchToSCF
|
||||
MLIRMemRefTransforms)
|
||||
|
||||
if(TORCH_MLIR_ENABLE_MHLO)
|
||||
list(APPEND LinkedLibs ChloPasses)
|
||||
endif()
|
||||
|
||||
add_mlir_library(TorchMLIRTorchConversionPasses
|
||||
BackendTypeConversion.cpp
|
||||
BackendTypeConversionPasses.cpp
|
||||
|
@ -17,15 +33,5 @@ add_mlir_library(TorchMLIRTorchConversionPasses
|
|||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRFuncTransforms
|
||||
TorchMLIRTorchConversionDialect
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRTorchPasses
|
||||
TorchMLIRTorchToLinalg
|
||||
TorchMLIRTorchToTMTensor
|
||||
TorchMLIRTorchToStd
|
||||
TorchMLIRTorchToSCF
|
||||
MLIRMemRefTransforms
|
||||
)
|
||||
${LinkedLibs}
|
||||
)
|
|
@@ -21,6 +21,7 @@
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#ifdef TORCH_MLIR_ENABLE_MHLO
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@@ -145,10 +146,20 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
    // The resolution of `dim` ops tends to create identical ops. CSE them.
    pm.addNestedPass<func::FuncOp>(createCSEPass());
  }

  // Convert CHLO ops to MHLO ops
  pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
  if (options.optimize) {
    // Clean up any non-canonical code introduced above..
    pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
    // The resolution of `dim` ops tends to create identical ops. CSE them.
    pm.addNestedPass<func::FuncOp>(createCSEPass());
  }

  // Finish the type conversion from `torch` types to the types of the
  // MHLO backend contract.
  pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
  pm.addNestedPass<func::FuncOp>(
      TorchConversion::createFinalizingBackendTypeConversionPass());
}
#endif
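The pipeline change above is the other half of the CHLO story: the conversion patterns now produce chlo ops, so createTorchBackendToMhloBackendPipeline has to run chlo-legalize-to-hlo before the backend type conversion, otherwise the MHLO backend contract would still see CHLO. A condensed ordering sketch follows; only the last three calls appear in the hunk above, and the first line is an assumption about earlier lines of the same function rather than something shown in this diff.

```cpp
// Assumed earlier in the function: the Torch-to-MHLO conversion itself.
pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass()); // emits mhlo + chlo
// From the hunk above: expand chlo (BroadcastAddOp, ErfOp, ...) into mhlo,
// then finish the torch-to-builtin type conversion.
pm.addNestedPass<func::FuncOp>(mhlo::createChloLegalizeToHloPass());
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addNestedPass<func::FuncOp>(
    TorchConversion::createFinalizingBackendTypeConversionPass());
```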
File diff suppressed because it is too large