2022-07-21 07:18:16 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
2022-07-21 07:18:16 +08:00
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
2023-02-02 21:29:47 +08:00
|
|
|
#include "PopulatePatterns.h"
|
2024-03-04 23:31:54 +08:00
|
|
|
#include "Utils.h"
|
2023-02-02 21:29:47 +08:00
|
|
|
|
2022-10-05 21:28:06 +08:00
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2023-07-27 18:35:25 +08:00
|
|
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
2022-08-02 12:53:24 +08:00
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
2022-08-31 03:44:00 +08:00
|
|
|
#include "stablehlo/dialect/ChloOps.h"
|
2023-02-02 21:29:47 +08:00
|
|
|
#include "stablehlo/dialect/StablehloOps.h"
|
2023-05-25 10:32:55 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
2022-07-21 07:18:16 +08:00
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
2023-05-25 10:32:55 +08:00
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
2022-07-21 07:18:16 +08:00
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
|
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
2024-05-09 11:39:13 +08:00
|
|
|
#include <cmath>
|
2022-07-21 07:18:16 +08:00
|
|
|
#include <numeric>
|
2024-04-16 04:45:10 +08:00
|
|
|
#include <type_traits>
|
2022-07-21 07:18:16 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
2023-02-02 21:29:47 +08:00
|
|
|
using namespace mlir::torch::torch_to_stablehlo;
|
2022-07-21 07:18:16 +08:00
|
|
|
|
2022-11-24 14:28:34 +08:00
|
|
|
// Equalizes the ranks of `self` and `other` in place.
//
// If either operand has rank 0, or both ranks already agree, nothing is
// changed. Otherwise the lower-rank operand is replaced by the result of
// hlo::unsqueezeTensor applied at dimensions [0, rankDifference).
//
// Returns failure() when either operand is not a RankedTensorType or when
// hlo::unsqueezeTensor fails; the operands are left untouched in that case.
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
                             mlir::Value &self, mlir::Value &other,
                             size_t dimSizeIndexBits) {
  auto selfTy = dyn_cast<RankedTensorType>(self.getType());
  auto otherTy = dyn_cast<RankedTensorType>(other.getType());
  // dyn_cast yields a null type for anything that is not a ranked tensor;
  // calling getRank() on it would crash, so report failure instead.
  if (!selfTy || !otherTy)
    return failure();

  auto selfRank = selfTy.getRank();
  auto otherRank = otherTy.getRank();
  if (selfRank == 0 || otherRank == 0)
    return success();

  // Unsqueezes `operand` at dimensions [0, numDims) and stores the result
  // back into it.
  auto unsqueezeLeading = [&](mlir::Value &operand,
                              int64_t numDims) -> LogicalResult {
    auto unsqueezeDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numDims));
    auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, operand,
                                              unsqueezeDims, dimSizeIndexBits);
    if (failed(unsqueezeInfo))
      return failure();
    operand = *unsqueezeInfo;
    return success();
  };

  if (selfRank > otherRank)
    return unsqueezeLeading(other, selfRank - otherRank);
  if (otherRank > selfRank)
    return unsqueezeLeading(self, otherRank - selfRank);
  return success();
}
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
// Returns true when `alphaValue` matches a Torch constant float equal to 1.0
// or a Torch constant int equal to 1, i.e. when multiplying by alpha can be
// skipped. Any value that matches neither pattern returns false.
bool skipMultiplyAlpha(Value alphaValue) {
  // Initialized so the locals never hold indeterminate values when the
  // corresponding pattern does not match.
  double doubleValue = 0.0;
  const bool isFloat =
      matchPattern(alphaValue, m_TorchConstantFloat(&doubleValue));

  int64_t intValue = 0;
  const bool isInt = matchPattern(alphaValue, m_TorchConstantInt(&intValue));

  // Compare the integer against an integer literal rather than 1.0 to avoid
  // an implicit int -> double conversion.
  return (isFloat && doubleValue == 1.0) || (isInt && intValue == 1);
}
|
|
|
|
|
2022-09-16 15:09:21 +08:00
|
|
|
// Creates a rank-0 stablehlo::ConstantOp of `elementType` holding the largest
// value of that type and returns its result:
//   - float types:            +infinity
//   - unsigned integer types: APInt::getMaxValue(width)
//   - other integer types:    APInt::getSignedMaxValue(width)
// Returns failure() for any other element type.
static FailureOr<Value> getMaxValueOfDtype(Operation *op, Type elementType,
                                           PatternRewriter &rewriter) {
  auto scalarTy = RankedTensorType::get({}, elementType);

  if (auto floatTy = dyn_cast<mlir::FloatType>(elementType)) {
    auto infAttr = SplatElementsAttr::get(
        scalarTy,
        APFloat::getInf(floatTy.getFloatSemantics(), /*negative=*/false));
    return rewriter
        .create<stablehlo::ConstantOp>(op->getLoc(), scalarTy, infAttr)
        .getResult();
  }

  auto intTy = dyn_cast<mlir::IntegerType>(elementType);
  if (!intTy)
    return failure();

  DenseElementsAttr maxAttr;
  if (intTy.isUnsigned())
    maxAttr = SplatElementsAttr::get(scalarTy,
                                     APInt::getMaxValue(intTy.getWidth()));
  else
    maxAttr = SplatElementsAttr::get(
        scalarTy, APInt::getSignedMaxValue(intTy.getWidth()));

  return rewriter
      .create<stablehlo::ConstantOp>(op->getLoc(), scalarTy, maxAttr)
      .getResult();
}
|
|
|
|
|
|
|
|
// Creates a rank-0 stablehlo::ConstantOp of `elementType` holding the smallest
// value of that type and returns its result:
//   - float types:            -infinity
//   - unsigned integer types: APInt::getMinValue(width)
//   - other integer types:    APInt::getSignedMinValue(width)
// Returns failure() for any other element type.
static FailureOr<Value> getMinValueOfDtype(Operation *op, Type elementType,
                                           PatternRewriter &rewriter) {
  auto scalarTy = RankedTensorType::get({}, elementType);

  if (auto floatTy = dyn_cast<mlir::FloatType>(elementType)) {
    auto negInfAttr = SplatElementsAttr::get(
        scalarTy,
        APFloat::getInf(floatTy.getFloatSemantics(), /*negative=*/true));
    return rewriter
        .create<stablehlo::ConstantOp>(op->getLoc(), scalarTy, negInfAttr)
        .getResult();
  }

  auto intTy = dyn_cast<mlir::IntegerType>(elementType);
  if (!intTy)
    return failure();

  DenseElementsAttr minAttr;
  if (intTy.isUnsigned())
    minAttr = SplatElementsAttr::get(scalarTy,
                                     APInt::getMinValue(intTy.getWidth()));
  else
    minAttr = SplatElementsAttr::get(
        scalarTy, APInt::getSignedMinValue(intTy.getWidth()));

  return rewriter
      .create<stablehlo::ConstantOp>(op->getLoc(), scalarTy, minAttr)
      .getResult();
}
|
|
|
|
|
2023-01-12 06:40:03 +08:00
|
|
|
// These legalizations are for unary ops.
|
|
|
|
namespace {
|
2023-02-02 21:29:47 +08:00
|
|
|
template <typename AtenOpT, typename StablehloOpT>
|
2023-01-12 06:40:03 +08:00
|
|
|
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
Value self = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto selfType = cast<TensorType>(self.getType());
|
2023-01-12 06:40:03 +08:00
|
|
|
if (!selfType) {
|
2023-02-02 21:29:47 +08:00
|
|
|
return op.emitError("only Tensor types supported in StableHLO");
|
2023-01-12 06:40:03 +08:00
|
|
|
}
|
|
|
|
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
|
|
|
->convertType(op.getType())
|
|
|
|
.template cast<TensorType>();
|
2023-06-26 00:04:17 +08:00
|
|
|
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
|
2023-01-12 06:40:03 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
// These legalizations are for unary ops with only for floating point datatypes.
|
|
|
|
// There is no supported quantized integer mode for these.
|
|
|
|
namespace {
|
2023-02-02 21:29:47 +08:00
|
|
|
template <typename AtenOpT, typename StablehloOpT>
|
2022-07-27 13:07:51 +08:00
|
|
|
class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto selfTy = cast<TensorType>(self.getType());
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
if (!selfTy)
|
2023-02-02 21:29:47 +08:00
|
|
|
return op.emitError("only Tensor types supported in StableHLO");
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
if (isa<mlir::FloatType>(selfTy.getElementType())) {
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<StablehloOpT>(
|
2022-07-27 13:07:51 +08:00
|
|
|
op,
|
|
|
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
|
|
|
op.getType()),
|
|
|
|
self);
|
|
|
|
return success();
|
|
|
|
} else {
|
|
|
|
return op.emitError(
|
2022-08-02 12:53:24 +08:00
|
|
|
"only floating-point datatype legalization supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2024-04-23 17:57:12 +08:00
|
|
|
// These legalizations are for unary ops with promoting to floating point
|
|
|
|
// datatypes.
|
|
|
|
namespace {
|
|
|
|
template <typename AtenOpT, typename StablehloOpT>
|
|
|
|
class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
Value self = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto selfTy = cast<TensorType>(self.getType());
|
2024-04-23 17:57:12 +08:00
|
|
|
if (!selfTy)
|
|
|
|
return op.emitError("only Tensor types supported in StableHLO");
|
|
|
|
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
|
|
|
->convertType(op.getType())
|
|
|
|
.template cast<TensorType>();
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
if (isa<mlir::FloatType>(resultTy.getElementType())) {
|
2024-04-23 17:57:12 +08:00
|
|
|
Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
|
|
|
|
rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
|
|
|
|
return success();
|
|
|
|
} else {
|
|
|
|
return op.emitError(
|
|
|
|
"only result to be floating-point datatype legalization supported");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
// aten.ones & aten.zeros
|
|
|
|
// Ref: Error checking based on the Torch to TOSA lowering
|
|
|
|
namespace {
|
|
|
|
template <typename AtenOpT, int fillVal>
|
|
|
|
class ConvertAtenConstPatternOp : public OpConversionPattern<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
|
|
|
|
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
|
|
|
->convertType(op.getType())
|
|
|
|
.template dyn_cast<TensorType>();
|
|
|
|
|
|
|
|
if (!outType)
|
2023-02-02 21:29:47 +08:00
|
|
|
return op.emitError("only Tensor types supported in StableHLO");
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
Type outElemTy = outType.getElementType();
|
|
|
|
if (!outElemTy.isIntOrFloat())
|
|
|
|
return op.emitError(
|
2022-08-02 12:53:24 +08:00
|
|
|
"only floating-point or integer datatype legalization supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
SmallVector<int64_t> shape;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(shape))) {
|
2022-08-02 12:53:24 +08:00
|
|
|
return op.emitError("shape must be a list of Scalar constants");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
int64_t size = 1;
|
|
|
|
for (auto s : shape)
|
|
|
|
size *= s;
|
|
|
|
|
|
|
|
SmallVector<int32_t> values(size, fillVal);
|
|
|
|
auto constOp =
|
2023-02-02 21:29:47 +08:00
|
|
|
hlo::getConstTensor<int32_t>(rewriter, op, values, shape).value();
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, constOp);
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
2023-07-29 21:55:49 +08:00
|
|
|
namespace {
|
|
|
|
// Casts a tensor of exactly one element to an elemental type.
|
|
|
|
// Many codes borrowed from
|
|
|
|
// `lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp`
|
|
|
|
template <typename AtenOpT>
|
|
|
|
class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputType = dyn_cast<RankedTensorType>(adaptor.getA().getType());
|
2023-07-29 21:55:49 +08:00
|
|
|
if (!inputType)
|
|
|
|
|
|
|
|
op.emitError("only Tensor types supported in StableHLO");
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Value input = adaptor.getA();
|
|
|
|
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
|
|
|
int64_t inputRank = inputSizes.size();
|
2024-04-28 05:00:56 +08:00
|
|
|
Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
|
2023-07-29 21:55:49 +08:00
|
|
|
|
|
|
|
Value constantOne =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
|
|
|
|
for (int64_t i = 0; i < inputRank; i++)
|
|
|
|
checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
|
|
|
|
|
|
|
|
Value constantZero =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
|
|
|
SmallVector<Value> indices(inputRank, constantZero);
|
|
|
|
Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
|
|
|
|
Type resultType =
|
|
|
|
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
|
|
|
rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
|
|
|
|
resultType, inputDtype));
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-09-23 20:39:15 +08:00
|
|
|
// The binary broadcast patterns
|
|
|
|
namespace {
|
|
|
|
template <typename AtenOpT, typename ChloOpT>
|
|
|
|
class ConvertAtenBinaryBroadcastOp : public OpConversionPattern<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto lhsTy = cast<TensorType>(lhs.getType());
|
2022-12-08 04:20:41 +08:00
|
|
|
Value rhs = adaptor.getOther();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto rhsTy = cast<TensorType>(rhs.getType());
|
2022-09-23 20:39:15 +08:00
|
|
|
|
|
|
|
if (!lhsTy || !rhsTy)
|
|
|
|
return op.emitError("only Tensor types supported");
|
|
|
|
|
2023-01-12 06:40:03 +08:00
|
|
|
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
|
|
|
->convertType(op.getType())
|
|
|
|
.template cast<TensorType>();
|
2022-09-23 20:39:15 +08:00
|
|
|
|
2023-06-26 00:04:17 +08:00
|
|
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
|
|
|
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
|
2022-09-23 20:39:15 +08:00
|
|
|
|
2023-01-12 06:40:03 +08:00
|
|
|
rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
|
|
|
|
/*broadcast_attr*/ nullptr);
|
2022-09-23 20:39:15 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
// These binary op legalizations are specific to add/sub which have an
|
|
|
|
// alpha multiplier.
|
|
|
|
namespace {
|
2022-08-02 12:53:24 +08:00
|
|
|
template <typename AtenOpT, typename ChloOpT>
|
2022-07-27 13:07:51 +08:00
|
|
|
class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());
|
2022-12-08 04:20:41 +08:00
|
|
|
Value rhs = adaptor.getOther();
|
2024-04-28 05:00:56 +08:00
|
|
|
RankedTensorType rhsType = dyn_cast<RankedTensorType>(rhs.getType());
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
if (!lhsType)
|
2023-02-02 21:29:47 +08:00
|
|
|
return op.emitError("only Tensor types supported in StableHLO");
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
|
|
|
->convertType(op.getType())
|
|
|
|
.template cast<TensorType>();
|
|
|
|
|
|
|
|
Type outElemTy = outType.getElementType();
|
|
|
|
if (!outElemTy.isIntOrFloat()) {
|
|
|
|
return op.emitError(
|
2022-08-02 12:53:24 +08:00
|
|
|
"only floating-point or integer datatype legalization supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
if (!rhsType) {
|
2023-02-02 21:29:47 +08:00
|
|
|
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
|
|
|
outElemTy);
|
2022-08-17 09:07:36 +08:00
|
|
|
if (isa<AtenRsubScalarOp>(op)) {
|
|
|
|
std::swap(lhs, rhs);
|
|
|
|
}
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
2022-08-02 12:53:24 +08:00
|
|
|
|
2023-06-26 00:04:17 +08:00
|
|
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
|
|
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
2022-08-02 12:53:24 +08:00
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!skipMultiplyAlpha(op.getAlpha())) {
|
2023-02-02 21:29:47 +08:00
|
|
|
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
|
|
|
|
adaptor.getAlpha(), outElemTy);
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
DenseI64ArrayAttr bcastDimensions;
|
2022-08-02 12:53:24 +08:00
|
|
|
rhs = rewriter.create<chlo::BroadcastMulOp>(op->getLoc(), rhs, alpha,
|
|
|
|
bcastDimensions);
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
DenseI64ArrayAttr bcastDimensions;
|
2022-08-02 12:53:24 +08:00
|
|
|
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
|
|
|
bcastDimensions);
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-08-02 12:53:24 +08:00
|
|
|
// Binary op legalizations for Mul/Div variants.
|
2022-07-27 13:07:51 +08:00
|
|
|
namespace {
|
2022-08-02 12:53:24 +08:00
|
|
|
template <typename AtenOpT, typename ChloOpT>
|
|
|
|
class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
|
2022-07-27 13:07:51 +08:00
|
|
|
public:
|
|
|
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto lhsType = dyn_cast<TensorType>(lhs.getType());
|
2022-12-08 04:20:41 +08:00
|
|
|
Value rhs = adaptor.getOther();
|
2024-04-28 05:00:56 +08:00
|
|
|
TensorType rhsType = dyn_cast<TensorType>(rhs.getType());
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
if (!lhsType)
|
2023-02-02 21:29:47 +08:00
|
|
|
return op.emitError("only Tensor types supported in StableHLO");
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2024-04-16 04:45:10 +08:00
|
|
|
auto outType = cast<TensorType>(
|
|
|
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
|
|
|
op.getType()));
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
Type outElemTy = outType.getElementType();
|
|
|
|
if (!outElemTy.isIntOrFloat()) {
|
|
|
|
return op.emitError(
|
2022-08-02 12:53:24 +08:00
|
|
|
"only floating-point or integer datatype legalization supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
|
2024-05-10 17:07:37 +08:00
|
|
|
if constexpr (std::is_same<AtenOpT, AtenSquareOp>()) {
|
2022-08-03 08:16:31 +08:00
|
|
|
rhs = lhs;
|
2024-05-10 17:07:37 +08:00
|
|
|
} else {
|
|
|
|
if (!rhsType) {
|
|
|
|
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
|
|
|
outElemTy);
|
|
|
|
}
|
2022-08-03 08:16:31 +08:00
|
|
|
}
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
DenseI64ArrayAttr bcastDimensions;
|
2023-06-26 00:04:17 +08:00
|
|
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
|
|
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
2022-08-06 23:38:06 +08:00
|
|
|
auto loc = op.getLoc();
|
|
|
|
Value result =
|
|
|
|
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
|
|
|
|
|
2024-05-10 17:07:37 +08:00
|
|
|
if constexpr (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
|
|
|
|
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
|
2022-08-06 23:38:06 +08:00
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2024-04-16 04:45:10 +08:00
|
|
|
auto tensorOp = dyn_cast<AtenDivTensorModeOp>(op.getOperation());
|
|
|
|
auto opRoundingMode =
|
|
|
|
tensorOp
|
|
|
|
? tensorOp.getRoundingMode()
|
|
|
|
: cast<AtenDivScalarModeOp>(op.getOperation()).getRoundingMode();
|
|
|
|
|
2022-08-06 23:38:06 +08:00
|
|
|
std::string roundingMode;
|
2024-04-16 04:45:10 +08:00
|
|
|
if (!matchPattern(opRoundingMode, m_TorchConstantStr(roundingMode))) {
|
2022-08-06 23:38:06 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only support constant str rounding mode");
|
2024-04-16 04:45:10 +08:00
|
|
|
}
|
2022-08-06 23:38:06 +08:00
|
|
|
|
2024-04-02 17:28:53 +08:00
|
|
|
// if trunc and int, do nothing
|
2024-04-11 21:47:35 +08:00
|
|
|
if (roundingMode == "trunc" && isa<mlir::FloatType>(outElemTy)) {
|
2022-08-06 23:38:06 +08:00
|
|
|
// "trunc" - rounds the results of the division towards zero. Equivalent
|
|
|
|
// to C-style integer division.
|
2023-02-02 21:29:47 +08:00
|
|
|
auto sign = rewriter.create<stablehlo::SignOp>(loc, result);
|
|
|
|
auto abs = rewriter.create<stablehlo::AbsOp>(loc, result);
|
|
|
|
auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
|
|
|
|
result = rewriter.create<stablehlo::MulOp>(loc, sign, floor).getResult();
|
2022-08-06 23:38:06 +08:00
|
|
|
}
|
|
|
|
if (roundingMode == "floor") {
|
|
|
|
// "floor" - rounds the results of the division down. Equivalent to
|
|
|
|
// floor division in Python (the // operator)
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(outElemTy))
|
2024-04-02 17:28:53 +08:00
|
|
|
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
|
|
|
else if (!outElemTy.isUnsignedInteger()) {
|
|
|
|
TensorType defaultIntToFloatType =
|
|
|
|
outType.cloneWith(outType.getShape(), rewriter.getF64Type());
|
|
|
|
lhs =
|
|
|
|
hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType);
|
|
|
|
rhs =
|
|
|
|
hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType);
|
|
|
|
result = rewriter.create<ChloOpT>(loc, defaultIntToFloatType, lhs, rhs,
|
|
|
|
bcastDimensions);
|
|
|
|
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
|
|
|
result = hlo::promoteType(rewriter, op.getLoc(), result, outType);
|
|
|
|
}
|
2022-08-06 23:38:06 +08:00
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, result);
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Binary op legalizations for comparator ops.
|
|
|
|
namespace {
|
|
|
|
template <typename AtenOpT>
|
|
|
|
class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = adaptor.getSelf();
|
|
|
|
Value rhs = adaptor.getOther();
|
2024-04-28 05:00:56 +08:00
|
|
|
RankedTensorType lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
|
|
|
|
RankedTensorType rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2024-05-10 02:20:06 +08:00
|
|
|
if (!lhsTy) {
|
2023-02-02 21:29:47 +08:00
|
|
|
return op.emitError("only Tensor types supported in StableHLO");
|
2024-05-10 02:20:06 +08:00
|
|
|
}
|
|
|
|
if (!rhsTy) {
|
|
|
|
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
|
|
|
rhs.getType());
|
|
|
|
rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
|
|
|
}
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2024-05-10 02:20:06 +08:00
|
|
|
auto outType = cast<RankedTensorType>(
|
|
|
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
|
|
|
op.getType()));
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
Type lhsElemTy = lhsTy.getElementType();
|
2024-05-10 02:20:06 +08:00
|
|
|
Type rhsElemTy = rhsTy.getElementType();
|
|
|
|
if (!lhsElemTy.isIntOrFloat() || !rhsElemTy.isIntOrFloat()) {
|
2022-07-27 13:07:51 +08:00
|
|
|
return op.emitError(
|
2022-08-02 12:53:24 +08:00
|
|
|
"only floating-point or integer datatype legalization supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
|
2024-05-10 02:20:06 +08:00
|
|
|
if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
|
|
|
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy);
|
|
|
|
} else if (isa<mlir::FloatType>(lhsElemTy) &&
|
|
|
|
isa<mlir::IntegerType>(rhsElemTy)) {
|
|
|
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
|
|
|
|
} else {
|
|
|
|
if (lhsElemTy.getIntOrFloatBitWidth() >
|
|
|
|
rhsElemTy.getIntOrFloatBitWidth()) {
|
|
|
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
|
|
|
|
} else {
|
|
|
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy);
|
|
|
|
}
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
2024-05-10 02:20:06 +08:00
|
|
|
lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2022-08-31 03:44:00 +08:00
|
|
|
chlo::ComparisonTypeAttr compareTypeAttr;
|
|
|
|
chlo::ComparisonDirectionAttr compareDirectionAttr;
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(lhsElemTy)) {
|
2022-08-31 03:44:00 +08:00
|
|
|
compareTypeAttr = chlo::ComparisonTypeAttr::get(
|
|
|
|
op->getContext(), chlo::ComparisonType::FLOAT);
|
2024-04-11 21:47:35 +08:00
|
|
|
} else if (isa<mlir::IntegerType>(lhsElemTy)) {
|
2022-08-31 03:44:00 +08:00
|
|
|
compareTypeAttr = chlo::ComparisonTypeAttr::get(
|
|
|
|
op->getContext(), chlo::ComparisonType::SIGNED);
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
|
2024-05-10 17:07:37 +08:00
|
|
|
if constexpr (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
|
|
|
std::is_same<AtenOpT, AtenLtScalarOp>()) {
|
2022-08-31 03:44:00 +08:00
|
|
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
|
|
|
op->getContext(), chlo::ComparisonDirection::LT);
|
2024-05-10 17:07:37 +08:00
|
|
|
} else if constexpr (std::is_same<AtenOpT, AtenGtTensorOp>() ||
|
|
|
|
std::is_same<AtenOpT, AtenGtScalarOp>()) {
|
2022-08-31 03:44:00 +08:00
|
|
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
|
|
|
op->getContext(), chlo::ComparisonDirection::GT);
|
2024-05-10 17:07:37 +08:00
|
|
|
} else if constexpr (std::is_same<AtenOpT, AtenGeTensorOp>() ||
|
|
|
|
std::is_same<AtenOpT, AtenGeScalarOp>()) {
|
2022-11-24 14:28:34 +08:00
|
|
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
|
|
|
op->getContext(), chlo::ComparisonDirection::GE);
|
2024-05-10 17:07:37 +08:00
|
|
|
} else if constexpr (std::is_same<AtenOpT, AtenEqTensorOp>() ||
|
|
|
|
std::is_same<AtenOpT, AtenEqScalarOp>()) {
|
2022-08-31 03:44:00 +08:00
|
|
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
|
|
|
op->getContext(), chlo::ComparisonDirection::EQ);
|
2024-05-10 17:07:37 +08:00
|
|
|
} else if constexpr (std::is_same<AtenOpT, AtenNeTensorOp>() ||
|
|
|
|
std::is_same<AtenOpT, AtenNeScalarOp>()) {
|
2022-08-31 03:44:00 +08:00
|
|
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
|
|
|
op->getContext(), chlo::ComparisonDirection::NE);
|
2024-05-10 17:07:37 +08:00
|
|
|
} else if constexpr (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
|
|
|
std::is_same<AtenOpT, AtenLtScalarOp>()) {
|
2022-11-24 14:28:34 +08:00
|
|
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
|
|
|
op->getContext(), chlo::ComparisonDirection::LT);
|
2024-05-10 17:07:37 +08:00
|
|
|
} else if constexpr (std::is_same<AtenOpT, AtenLeTensorOp>() ||
|
|
|
|
std::is_same<AtenOpT, AtenLeScalarOp>()) {
|
2022-11-24 14:28:34 +08:00
|
|
|
compareDirectionAttr = chlo::ComparisonDirectionAttr::get(
|
|
|
|
op->getContext(), chlo::ComparisonDirection::LE);
|
|
|
|
} else {
|
|
|
|
return op.emitError("operator haven't been supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
DenseI64ArrayAttr bcastDimensions;
|
2022-08-02 12:53:24 +08:00
|
|
|
rewriter.replaceOpWithNewOp<chlo::BroadcastCompareOp>(
|
|
|
|
op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr,
|
2022-07-27 13:07:51 +08:00
|
|
|
compareTypeAttr);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
2023-01-04 10:11:25 +08:00
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Binary op legalizations for Logical And/Or/Xor.
|
|
|
|
namespace {
|
|
|
|
template <typename AtenOpT, typename ChloOpT>
|
|
|
|
class ConvertAtenLogicalBinaryOp : public OpConversionPattern<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2023-01-04 10:11:25 +08:00
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2024-04-08 20:24:17 +08:00
|
|
|
Value lhs = adaptor.getSelf();
|
|
|
|
Value rhs = adaptor.getOther();
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
RankedTensorType lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
|
|
|
|
RankedTensorType rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
2024-04-08 20:24:17 +08:00
|
|
|
|
|
|
|
if (!lhsTy)
|
|
|
|
return op.emitError("lhs must be a ranked tensor type");
|
|
|
|
|
2023-01-04 10:11:25 +08:00
|
|
|
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
|
|
|
->convertType(op.getType())
|
|
|
|
.template cast<TensorType>();
|
2024-04-08 20:24:17 +08:00
|
|
|
Type outElemTy = outType.getElementType();
|
|
|
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
|
|
|
if (!rhsTy) {
|
|
|
|
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
|
|
|
|
}
|
|
|
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
2023-01-04 10:11:25 +08:00
|
|
|
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
DenseI64ArrayAttr bcastDimensions;
|
2023-01-04 10:11:25 +08:00
|
|
|
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
|
|
|
bcastDimensions);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
2022-07-27 13:07:51 +08:00
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// AtenTransposeIntOp
|
|
|
|
namespace {
|
|
|
|
class ConvertAtenTransposeIntOp
|
|
|
|
: public OpConversionPattern<AtenTransposeIntOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenTransposeIntOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = adaptor.getSelf();
|
2022-07-27 13:07:51 +08:00
|
|
|
int64_t dim0;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getDim0(), m_TorchConstantInt(&dim0))) {
|
2022-07-27 13:07:51 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "dim0 must be constant");
|
|
|
|
}
|
|
|
|
int64_t dim1;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) {
|
2022-07-27 13:07:51 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
|
|
|
|
}
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inType = cast<RankedTensorType>(self.getType());
|
2022-07-27 13:07:51 +08:00
|
|
|
auto inputRank = inType.getRank();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto outType = cast<RankedTensorType>(
|
|
|
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
dim0 = toPositiveDim(dim0, inputRank);
|
|
|
|
if (!isValidDim(dim0, inputRank)) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim0 out of range");
|
|
|
|
}
|
|
|
|
dim1 = toPositiveDim(dim1, inputRank);
|
|
|
|
if (!isValidDim(dim1, inputRank)) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim1 out of range");
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<int64_t> permValues(inputRank);
|
|
|
|
std::iota(std::begin(permValues), std::end(permValues), 0);
|
|
|
|
std::swap(permValues[dim0], permValues[dim1]);
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
|
2023-12-08 15:13:42 +08:00
|
|
|
permValues);
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2022-09-23 10:24:36 +08:00
|
|
|
// AtenToDtypeOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
|
|
|
AtenToDtypeOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = adaptor.getSelf();
|
2022-09-23 10:24:36 +08:00
|
|
|
auto outType =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self);
|
2022-09-23 10:24:36 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
|
|
|
|
AtenSizeIntOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
// Not a tensor type.
|
2024-04-28 05:00:56 +08:00
|
|
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
2022-09-23 10:24:36 +08:00
|
|
|
if (!selfType)
|
|
|
|
return op.emitError("only tensor types are currently supported");
|
2022-11-21 21:50:35 +08:00
|
|
|
|
|
|
|
Value dim;
|
|
|
|
int64_t dimInt;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) {
|
2022-11-21 21:50:35 +08:00
|
|
|
dimInt = toPositiveDim(dimInt, selfType.getRank());
|
2023-04-07 19:49:35 +08:00
|
|
|
if (!isValidDim(dimInt, selfType.getRank()))
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
2022-11-21 21:50:35 +08:00
|
|
|
dim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dimInt);
|
|
|
|
} else {
|
|
|
|
Value inputRank = rewriter.create<arith::ConstantOp>(
|
|
|
|
op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank()));
|
2023-02-02 21:29:47 +08:00
|
|
|
dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(),
|
|
|
|
inputRank);
|
2022-11-21 21:50:35 +08:00
|
|
|
dim = rewriter.create<arith::IndexCastOp>(op.getLoc(),
|
|
|
|
rewriter.getIndexType(), dim);
|
|
|
|
}
|
|
|
|
|
2022-09-23 10:24:36 +08:00
|
|
|
auto dimSize = rewriter.create<tensor::DimOp>(
|
2022-12-08 04:20:41 +08:00
|
|
|
op.getLoc(), rewriter.getIndexType(), adaptor.getSelf(), dim);
|
2022-09-23 10:24:36 +08:00
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
|
|
|
|
op, getTypeConverter()->convertType(op.getType()), dimSize);
|
2022-11-21 21:50:35 +08:00
|
|
|
|
2022-09-23 10:24:36 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-11-24 14:28:34 +08:00
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
2023-02-02 21:29:47 +08:00
|
|
|
AtenWhereSelfOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = adaptor.getSelf();
|
|
|
|
Value cond = adaptor.getCondition();
|
|
|
|
Value other = adaptor.getOther();
|
2022-11-24 14:28:34 +08:00
|
|
|
|
2023-05-06 06:21:55 +08:00
|
|
|
auto outType =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
2023-05-06 06:21:55 +08:00
|
|
|
// promote self and other types
|
2023-06-26 00:04:17 +08:00
|
|
|
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
|
|
|
|
other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
|
2023-05-06 06:21:55 +08:00
|
|
|
|
2022-11-24 14:28:34 +08:00
|
|
|
if (failed(
|
|
|
|
broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits)))
|
|
|
|
return op.emitError("failed broadcast self and condition ranks");
|
|
|
|
|
|
|
|
if (failed(
|
|
|
|
broadcastRanks(rewriter, op, other, cond, options.dimSizeIndexBits)))
|
|
|
|
return op.emitError("failed broadcast other and condition ranks");
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<chlo::BroadcastSelectOp>(
|
2023-02-02 21:29:47 +08:00
|
|
|
op, getTypeConverter()->convertType(op.getType()),
|
2022-11-24 14:28:34 +08:00
|
|
|
ArrayRef<Value>{cond, self, other});
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
// AtenBroadcastToOp
|
2022-09-01 10:36:02 +08:00
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
|
|
|
AtenBroadcastToOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
|
|
|
auto outType = cast<RankedTensorType>(
|
|
|
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
2022-09-01 10:36:02 +08:00
|
|
|
|
|
|
|
if (options.enableStaticShape && selfTy.hasStaticShape()) {
|
2023-02-02 21:29:47 +08:00
|
|
|
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
|
2022-09-01 10:36:02 +08:00
|
|
|
rewriter.replaceOp(op, bcastOp);
|
|
|
|
return success();
|
|
|
|
}
|
2022-08-02 12:53:24 +08:00
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
SmallVector<Value> shape;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!(getListConstructElements(adaptor.getSize(), shape))) {
|
2022-09-01 10:36:02 +08:00
|
|
|
return op->emitError("desired shape must be a list of scalar");
|
|
|
|
}
|
|
|
|
SmallVector<Value> bcastShapeVec;
|
|
|
|
int64_t totalRank = shape.size();
|
|
|
|
int64_t selfRank = selfTy.getRank();
|
|
|
|
int64_t leadingRank = totalRank - selfRank;
|
|
|
|
|
|
|
|
for (int64_t i = 0; i < totalRank; ++i) {
|
|
|
|
Value dValue = shape[i];
|
|
|
|
Value newD;
|
|
|
|
int64_t dInt;
|
2022-09-23 10:24:36 +08:00
|
|
|
if (i >= leadingRank && matchPattern(dValue, m_TorchConstantInt(&dInt)) &&
|
|
|
|
dInt == -1) {
|
2022-09-01 10:36:02 +08:00
|
|
|
newD = rewriter.create<mlir::tensor::DimOp>(op->getLoc(), self,
|
|
|
|
i - leadingRank);
|
|
|
|
} else {
|
|
|
|
dValue = rewriter.create<torch::TorchConversion::ToI64Op>(op->getLoc(),
|
|
|
|
dValue);
|
|
|
|
newD = rewriter.create<mlir::arith::IndexCastOp>(
|
|
|
|
op->getLoc(), rewriter.getIndexType(), dValue);
|
2022-08-02 12:53:24 +08:00
|
|
|
}
|
2022-09-01 10:36:02 +08:00
|
|
|
bcastShapeVec.push_back(newD);
|
|
|
|
}
|
2022-08-02 12:53:24 +08:00
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
if (options.dimSizeIndexBits == 32) {
|
2022-08-02 12:53:24 +08:00
|
|
|
for (auto &dsize : bcastShapeVec) {
|
|
|
|
auto dsizeI64 = rewriter.create<mlir::arith::IndexCastOp>(
|
|
|
|
op->getLoc(), rewriter.getI64Type(), dsize);
|
|
|
|
dsize = rewriter.create<arith::TruncIOp>(op->getLoc(),
|
|
|
|
rewriter.getI32Type(), dsizeI64);
|
|
|
|
}
|
2022-09-01 10:36:02 +08:00
|
|
|
}
|
2022-08-02 12:53:24 +08:00
|
|
|
|
2022-09-23 10:24:36 +08:00
|
|
|
if (bcastShapeVec.size() == 0) {
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, self);
|
|
|
|
} else {
|
2022-08-02 12:53:24 +08:00
|
|
|
Value bcastShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
|
|
|
op->getLoc(), ValueRange{bcastShapeVec});
|
|
|
|
auto dimensionNumbers =
|
|
|
|
llvm::to_vector<4>(llvm::seq<int64_t>(leadingRank, totalRank));
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::DynamicBroadcastInDimOp>(
|
2022-08-02 12:53:24 +08:00
|
|
|
op, outType, self, bcastShapeTensor,
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
rewriter.getDenseI64ArrayAttr(dimensionNumbers));
|
2022-09-23 10:24:36 +08:00
|
|
|
}
|
|
|
|
return success();
|
2022-09-01 10:36:02 +08:00
|
|
|
}
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
// AtenPermuteOp
|
2022-09-01 10:36:02 +08:00
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
|
|
|
AtenPermuteOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value self = adaptor.getSelf();
|
2022-09-01 10:36:02 +08:00
|
|
|
// Not a ranked tensor type
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inType = dyn_cast<RankedTensorType>(self.getType());
|
|
|
|
auto outType = cast<RankedTensorType>(
|
|
|
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
2022-09-01 10:36:02 +08:00
|
|
|
if (!inType)
|
|
|
|
return op.emitError("only ranked tensor types with static shapes are "
|
|
|
|
"currently supported");
|
|
|
|
|
|
|
|
SmallVector<int64_t> permValues;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(permValues)))
|
2022-09-01 10:36:02 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only constant dimensions are currently supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
int64_t inRank = inType.getRank();
|
|
|
|
for (auto &d : permValues) {
|
|
|
|
d = toPositiveDim(d, inRank);
|
|
|
|
if (!isValidDim(d, inRank))
|
|
|
|
return op.emitError("not all dims are valid");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
2022-07-21 07:18:16 +08:00
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::TransposeOp>(op, outType, self,
|
2023-12-08 15:13:42 +08:00
|
|
|
permValues);
|
2022-09-01 10:36:02 +08:00
|
|
|
return success();
|
|
|
|
}
|
2022-07-21 07:18:16 +08:00
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
// ValueTensorLiteralOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
|
|
|
ValueTensorLiteralOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2024-04-28 05:00:56 +08:00
|
|
|
RankedTensorType resultType = cast<RankedTensorType>(
|
|
|
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
// Tensors with integer types need to be converted to signless integer
|
|
|
|
// element type. All tensors with element types other than integer can reuse
|
|
|
|
// existing elements attribute.
|
2022-08-02 12:53:24 +08:00
|
|
|
// TODO: what about unsigned integer?
|
2024-04-28 05:00:56 +08:00
|
|
|
if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
|
2022-07-27 13:07:51 +08:00
|
|
|
Type builtinTensorElemTy = resultType.getElementType();
|
|
|
|
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
|
|
|
|
|
|
|
|
DenseElementsAttr valueAttr =
|
|
|
|
elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
|
|
|
|
return APInt(bitWidth, v.getSExtValue());
|
|
|
|
});
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
|
|
|
|
valueAttr);
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(op, resultType,
|
|
|
|
adaptor.getValue());
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2023-07-20 16:46:44 +08:00
|
|
|
// AtenTensorIntOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenTensorIntOp>::matchAndRewrite(
|
|
|
|
AtenTensorIntOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2024-04-28 05:00:56 +08:00
|
|
|
RankedTensorType resultType = cast<RankedTensorType>(
|
|
|
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
2023-07-20 16:46:44 +08:00
|
|
|
Type outElementType = resultType.getElementType();
|
|
|
|
Value innerValue = adaptor.getT();
|
|
|
|
Value stablehloTensor =
|
|
|
|
hlo::scalarToStablehloTensor(rewriter, op, innerValue, outElementType);
|
|
|
|
rewriter.replaceOp(op, stablehloTensor);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
// AtenReciprocalOp
|
|
|
|
// Reciprocal(x) = Div(1, x)
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
|
|
|
AtenReciprocalOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
2022-07-27 13:07:51 +08:00
|
|
|
auto outTy =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
|
|
|
if (!isa<mlir::FloatType>(inputTy.getElementType())) {
|
2022-08-02 12:53:24 +08:00
|
|
|
return op.emitError("only floating-point datatype legalization supported "
|
2022-07-27 13:07:51 +08:00
|
|
|
"for AtenReciprocalOp");
|
|
|
|
}
|
2022-08-02 12:53:24 +08:00
|
|
|
|
2024-05-26 12:34:56 +08:00
|
|
|
Value oneTensor =
|
|
|
|
hlo::getConstantLike<int64_t>(rewriter, op->getLoc(), 1, input);
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2023-02-17 12:26:46 +08:00
|
|
|
// AtenPowTensorScalarOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|
|
|
AtenPowTensorScalarOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Value lhs = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto lhsType = dyn_cast<TensorType>(lhs.getType());
|
2023-02-17 12:26:46 +08:00
|
|
|
Value rhs = adaptor.getExponent();
|
2024-04-28 05:00:56 +08:00
|
|
|
TensorType rhsType = dyn_cast<TensorType>(rhs.getType());
|
2023-02-17 12:26:46 +08:00
|
|
|
|
|
|
|
if (!lhsType)
|
|
|
|
return op.emitError("only Tensor types supported in StableHLO");
|
|
|
|
|
|
|
|
auto outType = OpConversionPattern<AtenPowTensorScalarOp>::getTypeConverter()
|
|
|
|
->convertType(op.getType())
|
|
|
|
.template cast<TensorType>();
|
|
|
|
|
|
|
|
Type outElemTy = outType.getElementType();
|
|
|
|
if (!outElemTy.isIntOrFloat()) {
|
|
|
|
return op.emitError(
|
|
|
|
"only floating-point or integer datatype legalization supported");
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!rhsType) {
|
2023-05-25 10:32:55 +08:00
|
|
|
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
|
2023-02-17 12:26:46 +08:00
|
|
|
}
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
DenseI64ArrayAttr bcastDimensions;
|
2023-06-26 00:04:17 +08:00
|
|
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
|
|
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
2023-02-17 12:26:46 +08:00
|
|
|
auto loc = op.getLoc();
|
2023-05-25 10:32:55 +08:00
|
|
|
Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
|
|
|
|
bcastDimensions);
|
2023-02-17 12:26:46 +08:00
|
|
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2024-04-26 15:47:44 +08:00
|
|
|
// AtenPowScalarOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
|
|
|
|
AtenPowScalarOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Value lhs = adaptor.getSelf();
|
|
|
|
auto lhsType = dyn_cast<TensorType>(lhs.getType());
|
|
|
|
Value rhs = adaptor.getExponent();
|
|
|
|
auto rhsType = dyn_cast<TensorType>(rhs.getType());
|
|
|
|
|
|
|
|
if (!rhsType)
|
|
|
|
return op.emitError("only Tensor types supported in StableHLO");
|
|
|
|
|
|
|
|
auto outType = cast<TensorType>(
|
|
|
|
OpConversionPattern<AtenPowScalarOp>::getTypeConverter()->convertType(
|
|
|
|
op.getType()));
|
|
|
|
|
|
|
|
Type outElemTy = outType.getElementType();
|
|
|
|
if (!outElemTy.isIntOrFloat()) {
|
|
|
|
return op.emitError(
|
|
|
|
"only floating-point or integer datatype legalization supported");
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!lhsType) {
|
|
|
|
lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
|
|
|
|
}
|
|
|
|
DenseI64ArrayAttr bcastDimensions;
|
|
|
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
|
|
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
|
|
|
auto loc = op.getLoc();
|
|
|
|
Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
|
|
|
|
bcastDimensions);
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
// PrimNumToTensorScalarOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
|
|
|
PrimNumToTensorScalarOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2024-04-28 05:00:56 +08:00
|
|
|
RankedTensorType outputType = cast<RankedTensorType>(
|
|
|
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
2022-07-27 13:07:51 +08:00
|
|
|
auto outputElemType = outputType.getElementType();
|
2023-02-02 21:29:47 +08:00
|
|
|
Value stablehloTensor = hlo::scalarToStablehloTensor(
|
|
|
|
rewriter, op, adaptor.getA(), outputElemType);
|
|
|
|
rewriter.replaceOp(op, stablehloTensor);
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2023-09-04 14:04:09 +08:00
|
|
|
// AtenScalarImplicitOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenScalarImplicitOp>::matchAndRewrite(
|
|
|
|
AtenScalarImplicitOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2023-09-06 01:57:15 +08:00
|
|
|
Location loc = op.getLoc();
|
2024-04-28 05:00:56 +08:00
|
|
|
Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
|
2023-09-06 01:57:15 +08:00
|
|
|
Type resultType =
|
|
|
|
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
2024-01-30 01:59:33 +08:00
|
|
|
auto result = rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
|
2023-09-06 01:57:15 +08:00
|
|
|
|
|
|
|
rewriter.replaceOp(
|
|
|
|
op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype));
|
2023-09-04 14:04:09 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
// AtenContiguousOp
|
|
|
|
// Ref: TosaToTosa.cpp for implementation details
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
|
|
|
AtenContiguousOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
|
|
|
// Not a tensor type.
|
2024-04-28 05:00:56 +08:00
|
|
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
2022-07-27 13:07:51 +08:00
|
|
|
if (!selfType)
|
2022-08-02 12:53:24 +08:00
|
|
|
return op.emitError("only tensor types are currently supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
// FIXME: memory_format is not handled.
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOp(op, adaptor.getSelf());
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// AtenReluOp
|
|
|
|
// Relu(x) = Max(0, x)
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
|
|
|
AtenReluOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value lhs = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
2022-07-27 13:07:51 +08:00
|
|
|
auto lhsElemTy = lhsTy.getElementType();
|
|
|
|
|
2024-04-11 21:47:35 +08:00
|
|
|
if (!isa<mlir::FloatType>(lhsElemTy)) {
|
2022-08-02 12:53:24 +08:00
|
|
|
return op->emitError("only float tensor in relu op is supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
|
2024-05-26 12:34:56 +08:00
|
|
|
Value zeroTensor =
|
|
|
|
hlo::getConstantLike<int64_t>(rewriter, op->getLoc(), 0, lhs);
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-08-02 15:01:30 +08:00
|
|
|
// Convert a Aten::GELU to HLO
|
2024-05-09 11:39:13 +08:00
|
|
|
// Gelu(x, "none") = x * 0.5 * (1 + erf(x/(sqrt(2))))
|
|
|
|
// Gelu(x, "tanh") = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
2022-08-02 15:01:30 +08:00
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
2022-08-04 12:34:22 +08:00
|
|
|
AtenGeluOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-08-02 15:01:30 +08:00
|
|
|
Location loc = op.getLoc();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
2022-08-02 15:01:30 +08:00
|
|
|
if (!inputTy) {
|
|
|
|
return op.emitError("only ranked tensor type is supported.");
|
|
|
|
}
|
|
|
|
|
2024-05-09 11:39:13 +08:00
|
|
|
std::string approximate;
|
|
|
|
if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate))) {
|
|
|
|
return op.emitError("approximate must be constant string");
|
|
|
|
}
|
|
|
|
if (approximate != "none" && approximate != "tanh") {
|
|
|
|
return op.emitError("unsupported approximate: ") << approximate;
|
|
|
|
}
|
|
|
|
|
2024-05-26 12:34:56 +08:00
|
|
|
Value one = hlo::getConstantLike(rewriter, loc, 1.0, input);
|
|
|
|
Value two = hlo::getConstantLike(rewriter, loc, 2.0, input);
|
|
|
|
Value three = hlo::getConstantLike(rewriter, loc, 3.0, input);
|
|
|
|
Value half = hlo::getConstantLike(rewriter, loc, 0.5, input);
|
2024-05-09 11:39:13 +08:00
|
|
|
// 2/pi
|
2024-05-26 12:34:56 +08:00
|
|
|
Value twoDivPi = hlo::getConstantLike(rewriter, loc, M_2_PI, input);
|
|
|
|
Value t = hlo::getConstantLike(rewriter, loc, 0.044715, input);
|
2024-05-09 11:39:13 +08:00
|
|
|
|
|
|
|
// x * 0.5
|
|
|
|
auto inputMulHalf = rewriter.create<stablehlo::MulOp>(loc, input, half);
|
|
|
|
if (approximate == "none") {
|
|
|
|
auto rsqrtTwo = rewriter.create<stablehlo::RsqrtOp>(loc, two);
|
|
|
|
auto erfElement = rewriter.create<stablehlo::MulOp>(loc, input, rsqrtTwo);
|
|
|
|
auto erf = rewriter.create<chlo::ErfOp>(loc, erfElement);
|
|
|
|
auto erfAdd = rewriter.create<stablehlo::AddOp>(loc, erf, one);
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, erfAdd, inputMulHalf);
|
|
|
|
return success();
|
|
|
|
} else {
|
|
|
|
auto sqrtTwoPi = rewriter.create<stablehlo::SqrtOp>(loc, twoDivPi);
|
|
|
|
// x^3
|
|
|
|
auto powThree = rewriter.create<stablehlo::PowOp>(loc, input, three);
|
|
|
|
// x + 0.044715 * x^3
|
|
|
|
auto add = rewriter.create<stablehlo::AddOp>(
|
|
|
|
loc, input, rewriter.create<stablehlo::MulOp>(loc, t, powThree));
|
|
|
|
auto tanh = rewriter.create<stablehlo::TanhOp>(
|
|
|
|
loc, rewriter.create<stablehlo::MulOp>(loc, sqrtTwoPi, add));
|
|
|
|
auto tanhAdd = rewriter.create<stablehlo::AddOp>(loc, tanh, one);
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, tanhAdd, inputMulHalf);
|
|
|
|
return success();
|
|
|
|
}
|
2022-08-02 15:01:30 +08:00
|
|
|
}
|
|
|
|
|
2024-04-23 19:06:55 +08:00
|
|
|
// AtenLog2Op
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
|
|
|
AtenLog2Op op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Value input = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
2024-04-23 19:06:55 +08:00
|
|
|
if (!inputTy) {
|
|
|
|
return op.emitError("only ranked tensor type is supported.");
|
|
|
|
}
|
2024-04-28 05:00:56 +08:00
|
|
|
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
2024-04-23 19:06:55 +08:00
|
|
|
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
|
|
|
|
2024-05-26 12:34:56 +08:00
|
|
|
auto two = hlo::getConstantLike(rewriter, op.getLoc(), 2.0, input);
|
2024-04-23 19:06:55 +08:00
|
|
|
auto log2Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), two);
|
|
|
|
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, logInputOp, log2Op);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// AtenLog10Op
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
|
|
|
|
AtenLog10Op op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Value input = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
2024-04-23 19:06:55 +08:00
|
|
|
if (!inputTy) {
|
|
|
|
return op.emitError("only ranked tensor type is supported.");
|
|
|
|
}
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
2024-04-23 19:06:55 +08:00
|
|
|
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
|
|
|
|
2024-05-26 12:34:56 +08:00
|
|
|
auto ten = hlo::getConstantLike(rewriter, op.getLoc(), 10.0, input);
|
2024-04-23 19:06:55 +08:00
|
|
|
auto log10Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), ten);
|
|
|
|
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, logInputOp, log10Op);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
// AtenErfOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
|
|
|
AtenErfOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputType = cast<TensorType>(input.getType());
|
|
|
|
if (!isa<mlir::FloatType>(inputType.getElementType())) {
|
2022-08-02 12:53:24 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "only float tensor is supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
2022-08-02 12:53:24 +08:00
|
|
|
rewriter.replaceOpWithNewOp<chlo::ErfOp>(
|
|
|
|
op, getTypeConverter()->convertType(op.getType()), input);
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// AtenBatchNormOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|
|
|
AtenBatchNormOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getInput();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
2022-12-08 04:20:41 +08:00
|
|
|
Value weight = adaptor.getWeight();
|
|
|
|
Value bias = adaptor.getBias();
|
|
|
|
Value runningMean = adaptor.getRunningMean();
|
|
|
|
Value runningVar = adaptor.getRunningVar();
|
2022-07-27 13:07:51 +08:00
|
|
|
// momentum is ignored
|
2022-12-08 04:20:41 +08:00
|
|
|
Value momentum = adaptor.getMomentum();
|
2022-07-27 13:07:51 +08:00
|
|
|
(void)momentum;
|
|
|
|
|
2023-05-18 00:04:40 +08:00
|
|
|
// handle feature index, see torch's BatchNorm1d, BatchNorm2d, BatchNorm3d,
|
|
|
|
// all of NC, NCL, NCHW, NCDHW's feature index is 1.
|
|
|
|
int64_t feature_index = 1;
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
if (!isa<mlir::FloatType>(inputTy.getElementType())) {
|
2022-08-02 12:53:24 +08:00
|
|
|
return op.emitError("only input tensor of float type is supported");
|
|
|
|
}
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputElemTy = cast<mlir::FloatType>(inputTy.getElementType());
|
2022-08-02 12:53:24 +08:00
|
|
|
|
2023-08-21 17:36:56 +08:00
|
|
|
Value channelDim =
|
|
|
|
rewriter.create<tensor::DimOp>(op->getLoc(), input, feature_index);
|
2022-08-02 12:53:24 +08:00
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
if (options.dimSizeIndexBits == 32) {
|
|
|
|
auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>(
|
|
|
|
op->getLoc(), rewriter.getI64Type(), channelDim);
|
|
|
|
channelDim = rewriter.create<arith::TruncIOp>(
|
|
|
|
op->getLoc(), rewriter.getI32Type(), channelDimI64);
|
|
|
|
}
|
2022-08-02 12:53:24 +08:00
|
|
|
|
|
|
|
Value channelShape = rewriter.create<tensor::FromElementsOp>(
|
|
|
|
op->getLoc(), ValueRange{channelDim});
|
2022-07-27 13:07:51 +08:00
|
|
|
if (failed(checkNotNone(rewriter, op, weight))) {
|
2023-02-02 21:29:47 +08:00
|
|
|
weight = hlo::getConstantOfShape(
|
2022-08-02 12:53:24 +08:00
|
|
|
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
|
|
|
channelShape,
|
|
|
|
RankedTensorType::get({inputTy.getShape()[1]},
|
|
|
|
inputTy.getElementType()));
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
if (failed(checkNotNone(rewriter, op, bias))) {
|
2023-02-02 21:29:47 +08:00
|
|
|
bias = hlo::getConstantOfShape(
|
2022-08-02 12:53:24 +08:00
|
|
|
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
|
|
|
channelShape,
|
|
|
|
RankedTensorType::get({inputTy.getShape()[1]},
|
|
|
|
inputTy.getElementType()));
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
if (failed(checkNotNone(rewriter, op, runningVar))) {
|
2023-02-02 21:29:47 +08:00
|
|
|
runningVar = hlo::getConstantOfShape(
|
2022-08-02 12:53:24 +08:00
|
|
|
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1),
|
|
|
|
channelShape,
|
|
|
|
RankedTensorType::get({inputTy.getShape()[1]},
|
|
|
|
inputTy.getElementType()));
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
if (failed(checkNotNone(rewriter, op, runningMean))) {
|
2023-02-02 21:29:47 +08:00
|
|
|
runningMean = hlo::getConstantOfShape(
|
2022-08-02 12:53:24 +08:00
|
|
|
rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0),
|
|
|
|
channelShape,
|
|
|
|
RankedTensorType::get({inputTy.getShape()[1]},
|
|
|
|
inputTy.getElementType()));
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
|
|
|
auto biasTy = cast<RankedTensorType>(bias.getType());
|
|
|
|
auto runningMeanTy = cast<RankedTensorType>(runningMean.getType());
|
|
|
|
auto runningVarTy = cast<RankedTensorType>(runningVar.getType());
|
2022-08-02 12:53:24 +08:00
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
if (weightTy.getRank() != 1 || biasTy.getRank() != 1 ||
|
|
|
|
runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "expect weight, bias, running_mean and running_var to be rank 1");
|
|
|
|
}
|
2024-04-28 05:00:56 +08:00
|
|
|
if (!isa<mlir::FloatType>(weightTy.getElementType()) ||
|
|
|
|
!isa<mlir::FloatType>(biasTy.getElementType()) ||
|
|
|
|
!isa<mlir::FloatType>(runningMeanTy.getElementType()) ||
|
|
|
|
!isa<mlir::FloatType>(runningVarTy.getElementType())) {
|
2022-08-02 12:53:24 +08:00
|
|
|
return op.emitError("only float weight/bias/runningMean/runningVar tensor "
|
|
|
|
"of float type is supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
double eps = 0.0;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) {
|
2022-07-27 13:07:51 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "non-float(double) eps unsupported");
|
|
|
|
}
|
|
|
|
bool training = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
|
2022-07-27 13:07:51 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "non-bool training unsupported");
|
|
|
|
}
|
|
|
|
// TODO: handle cudnnEnabled parameter. Here, we just ignore it!
|
|
|
|
bool cudnnEnabled = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getCudnnEnabled(), m_TorchConstantBool(&cudnnEnabled))) {
|
2022-07-27 13:07:51 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"non-bool cudnn_enabled unsupported");
|
|
|
|
}
|
|
|
|
if (training) {
|
|
|
|
Type outputTy = getTypeConverter()->convertType(op.getType());
|
|
|
|
Type batchMeanOrVarTy =
|
|
|
|
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
|
2023-08-21 17:36:56 +08:00
|
|
|
|
|
|
|
Value output;
|
|
|
|
// supported mixed types, like input type is fp16 and weight type is fp32.
|
|
|
|
if (inputTy.getElementType() != weightTy.getElementType()) {
|
|
|
|
RankedTensorType convertedType = inputTy;
|
2024-04-28 05:00:56 +08:00
|
|
|
if (cast<FloatType>(weightTy.getElementType()).getWidth() >
|
|
|
|
cast<FloatType>(inputTy.getElementType()).getWidth()) {
|
2023-08-21 17:36:56 +08:00
|
|
|
convertedType = RankedTensorType::get(inputTy.getShape(),
|
|
|
|
weightTy.getElementType());
|
|
|
|
}
|
|
|
|
input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType);
|
|
|
|
weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
|
|
|
|
bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
|
|
|
|
auto batchNormTrainingResult =
|
|
|
|
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
|
|
|
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
|
|
|
weight, bias, rewriter.getF32FloatAttr(eps),
|
|
|
|
rewriter.getI64IntegerAttr(feature_index));
|
|
|
|
output = hlo::promoteType(rewriter, op.getLoc(),
|
|
|
|
batchNormTrainingResult.getResult(0),
|
2024-04-11 21:47:35 +08:00
|
|
|
cast<TensorType>(outputTy));
|
2023-08-21 17:36:56 +08:00
|
|
|
} else {
|
|
|
|
auto batchNormTrainingResult =
|
|
|
|
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
|
|
|
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
|
|
|
weight, bias, rewriter.getF32FloatAttr(eps),
|
|
|
|
rewriter.getI64IntegerAttr(feature_index));
|
|
|
|
output = batchNormTrainingResult.getResult(0);
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, output);
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
} else {
|
|
|
|
Type outputTy = getTypeConverter()->convertType(op.getType());
|
2022-09-23 20:50:29 +08:00
|
|
|
SmallVector<int64_t, 4> castShape{inputTy.getShape().begin(),
|
|
|
|
inputTy.getShape().end()};
|
|
|
|
castShape[1] = weightTy.getShape()[0];
|
|
|
|
auto castTy = RankedTensorType::get(castShape, inputTy.getElementType());
|
2023-02-02 21:29:47 +08:00
|
|
|
// Feature counts must match among operands of
|
|
|
|
// stablehlo::BatchNormInferenceOp.
|
2022-09-23 20:50:29 +08:00
|
|
|
Value inputCasted =
|
|
|
|
rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
|
2023-08-21 17:36:56 +08:00
|
|
|
|
|
|
|
Value output;
|
|
|
|
// supported mixed types, like input type is fp16 and weight type is fp32.
|
|
|
|
if (inputTy.getElementType() != weightTy.getElementType()) {
|
|
|
|
RankedTensorType convertedType = inputTy;
|
2024-04-28 05:00:56 +08:00
|
|
|
if (cast<FloatType>(weightTy.getElementType()).getWidth() >
|
|
|
|
cast<FloatType>(inputTy.getElementType()).getWidth()) {
|
2023-08-21 17:36:56 +08:00
|
|
|
convertedType = RankedTensorType::get(inputTy.getShape(),
|
|
|
|
weightTy.getElementType());
|
|
|
|
}
|
|
|
|
input =
|
|
|
|
hlo::promoteType(rewriter, op.getLoc(), inputCasted, convertedType);
|
|
|
|
weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
|
|
|
|
bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
|
|
|
|
runningMean =
|
|
|
|
hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType);
|
|
|
|
runningVar =
|
|
|
|
hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType);
|
|
|
|
Value bnResult = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
|
|
|
op.getLoc(), convertedType, input, weight, bias, runningMean,
|
|
|
|
runningVar, rewriter.getF32FloatAttr(eps),
|
|
|
|
rewriter.getI64IntegerAttr(feature_index));
|
|
|
|
output = hlo::promoteType(rewriter, op.getLoc(), bnResult,
|
2024-04-11 21:47:35 +08:00
|
|
|
cast<TensorType>(outputTy));
|
2023-08-21 17:36:56 +08:00
|
|
|
} else {
|
|
|
|
output = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
|
|
|
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
|
|
|
|
runningMean, runningVar,
|
|
|
|
// 'epsilon' must satisfy constraint: 32-bit float attribute.
|
|
|
|
rewriter.getF32FloatAttr(eps),
|
|
|
|
rewriter.getI64IntegerAttr(feature_index));
|
|
|
|
}
|
2022-09-23 20:50:29 +08:00
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, output);
|
2022-07-27 13:07:51 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// AtenNativeLayerNormOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
|
|
|
AtenNativeLayerNormOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getInput();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
2022-07-27 13:07:51 +08:00
|
|
|
auto inputShape = inputTy.getShape();
|
|
|
|
auto inputRank = inputTy.getRank();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value weight = adaptor.getWeight();
|
|
|
|
Value bias = adaptor.getBias();
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2022-08-02 12:53:24 +08:00
|
|
|
if (!inputTy.hasStaticShape()) {
|
|
|
|
return op->emitError("dynamic shaped input is not supported");
|
|
|
|
}
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
SmallVector<int64_t> normalizedShape;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getNormalizedShape(),
|
2022-11-17 04:33:12 +08:00
|
|
|
m_TorchListOfConstantInts(normalizedShape))) {
|
2022-07-27 13:07:51 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "normalized_shape must be a list of const int");
|
|
|
|
}
|
|
|
|
double eps = 0;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) {
|
2022-08-02 12:53:24 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"non const float eps is unsupported");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
if (failed(checkNotNone(rewriter, op, weight)) ||
|
|
|
|
failed(checkNotNone(rewriter, op, bias))) {
|
2022-08-02 12:53:24 +08:00
|
|
|
return op->emitError("none weight or bias is unsupported");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
2024-04-28 05:00:56 +08:00
|
|
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
|
|
|
auto biasTy = cast<RankedTensorType>(bias.getType());
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
if (!isa<mlir::FloatType>(inputTy.getElementType()) ||
|
|
|
|
!isa<mlir::FloatType>(biasTy.getElementType()) ||
|
|
|
|
!isa<mlir::FloatType>(weightTy.getElementType())) {
|
2022-08-02 12:53:24 +08:00
|
|
|
return op->emitError("currently only float data type are supported");
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
int64_t normalizedShapeRank = normalizedShape.size();
|
|
|
|
if (weightTy.getRank() != normalizedShapeRank ||
|
|
|
|
biasTy.getRank() != normalizedShapeRank ||
|
|
|
|
inputRank < normalizedShapeRank || normalizedShapeRank < 1) {
|
2022-08-02 12:53:24 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "input or weight or bias shape or"
|
2022-07-27 13:07:51 +08:00
|
|
|
"normalized shape not compatible");
|
|
|
|
}
|
|
|
|
for (int64_t i = 1; i <= normalizedShapeRank; i++) {
|
|
|
|
if (inputShape[inputRank - i] != normalizedShape[normalizedShapeRank - i] ||
|
|
|
|
weightTy.getShape()[normalizedShapeRank - i] !=
|
|
|
|
normalizedShape[normalizedShapeRank - i] ||
|
|
|
|
biasTy.getShape()[normalizedShapeRank - i] !=
|
|
|
|
normalizedShape[normalizedShapeRank - i]) {
|
|
|
|
return op.emitError("mismatching contracting dimension");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-08-02 12:53:24 +08:00
|
|
|
// Flatten dims to fit batch_norm operation.
|
2022-07-27 13:07:51 +08:00
|
|
|
int64_t numFeatureDimSize = 1;
|
|
|
|
int64_t numEmbeddingDimSize = 1;
|
|
|
|
for (int64_t i = 0; i < inputRank - normalizedShapeRank; i++) {
|
|
|
|
numFeatureDimSize *= inputShape[i];
|
|
|
|
}
|
|
|
|
for (int64_t i = 0; i < normalizedShapeRank; i++) {
|
|
|
|
numEmbeddingDimSize *= normalizedShape[i];
|
|
|
|
}
|
|
|
|
SmallVector<int64_t> inputFlattenShape{1, numFeatureDimSize,
|
|
|
|
numEmbeddingDimSize};
|
2023-02-02 21:29:47 +08:00
|
|
|
SmallVector<int64_t> meanOrVarStablehloOutShape{numFeatureDimSize};
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
auto stablehloBatchNormOutTy =
|
2022-07-27 13:07:51 +08:00
|
|
|
RankedTensorType::get(inputFlattenShape, inputTy.getElementType());
|
2023-02-02 21:29:47 +08:00
|
|
|
auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get(
|
|
|
|
meanOrVarStablehloOutShape, inputTy.getElementType());
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2022-08-02 12:53:24 +08:00
|
|
|
// Reshape input
|
2023-02-02 21:29:47 +08:00
|
|
|
auto stablehloInput = rewriter.create<stablehlo::DynamicReshapeOp>(
|
|
|
|
op->getLoc(), stablehloBatchNormOutTy, input,
|
|
|
|
hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape),
|
|
|
|
{static_cast<int64_t>(inputFlattenShape.size())})
|
2022-08-09 11:17:35 +08:00
|
|
|
.value());
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
// Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp.
|
2022-07-27 13:07:51 +08:00
|
|
|
SmallVector<APFloat> zeroConstVec(
|
|
|
|
numFeatureDimSize, APFloat::getZero(inputTy.getElementType()
|
|
|
|
.cast<mlir::FloatType>()
|
|
|
|
.getFloatSemantics()));
|
|
|
|
SmallVector<APFloat> oneConstVec(
|
|
|
|
numFeatureDimSize,
|
|
|
|
APFloat(
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<mlir::FloatType>(inputTy.getElementType()).getFloatSemantics(),
|
2022-07-27 13:07:51 +08:00
|
|
|
1));
|
|
|
|
auto oneOrZeroConstType =
|
|
|
|
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
Value scale = rewriter.create<stablehlo::ConstantOp>(
|
2022-07-27 13:07:51 +08:00
|
|
|
op->getLoc(), oneOrZeroConstType,
|
|
|
|
DenseElementsAttr::get(oneOrZeroConstType, oneConstVec));
|
2023-02-02 21:29:47 +08:00
|
|
|
Value offset = rewriter.create<stablehlo::ConstantOp>(
|
2022-07-27 13:07:51 +08:00
|
|
|
op->getLoc(), oneOrZeroConstType,
|
|
|
|
DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec));
|
2023-02-02 21:29:47 +08:00
|
|
|
auto batchNormTrainingResult =
|
|
|
|
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
|
|
|
op->getLoc(), stablehloBatchNormOutTy,
|
|
|
|
stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy,
|
|
|
|
stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps),
|
|
|
|
rewriter.getI64IntegerAttr(1));
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2022-08-02 12:53:24 +08:00
|
|
|
// Reshape back
|
2022-07-27 13:07:51 +08:00
|
|
|
auto outputTy =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
|
2022-07-27 13:07:51 +08:00
|
|
|
auto outputMeanOrVarTy =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
auto output = rewriter.create<stablehlo::DynamicReshapeOp>(
|
2022-07-27 13:07:51 +08:00
|
|
|
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
|
2023-02-02 21:29:47 +08:00
|
|
|
hlo::getConstTensor(rewriter, op, outputTy.getShape(),
|
|
|
|
{static_cast<int64_t>(outputTy.getShape().size())})
|
2022-08-09 11:17:35 +08:00
|
|
|
.value());
|
2023-02-02 21:29:47 +08:00
|
|
|
auto mean = rewriter.create<stablehlo::DynamicReshapeOp>(
|
2022-07-27 13:07:51 +08:00
|
|
|
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1),
|
2023-02-02 21:29:47 +08:00
|
|
|
hlo::getConstTensor(
|
2022-07-27 13:07:51 +08:00
|
|
|
rewriter, op, outputMeanOrVarTy.getShape(),
|
|
|
|
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
2022-08-09 11:17:35 +08:00
|
|
|
.value());
|
2023-02-02 21:29:47 +08:00
|
|
|
auto var = rewriter.create<stablehlo::DynamicReshapeOp>(
|
2022-07-27 13:07:51 +08:00
|
|
|
op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2),
|
2023-02-02 21:29:47 +08:00
|
|
|
hlo::getConstTensor(
|
2022-07-27 13:07:51 +08:00
|
|
|
rewriter, op, outputMeanOrVarTy.getShape(),
|
|
|
|
{static_cast<int64_t>(outputMeanOrVarTy.getShape().size())})
|
2022-08-09 11:17:35 +08:00
|
|
|
.value());
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
// Apply affine transform: output x weight + bias [element-wise]
|
2023-02-02 21:29:47 +08:00
|
|
|
auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy);
|
|
|
|
auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy);
|
2022-07-27 13:07:51 +08:00
|
|
|
auto outputMulWeight =
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.create<stablehlo::MulOp>(op->getLoc(), output, bcastedWeight);
|
|
|
|
auto finalOuput = rewriter.create<stablehlo::AddOp>(
|
|
|
|
op->getLoc(), outputMulWeight, bcastedBias);
|
2022-07-27 13:07:51 +08:00
|
|
|
rewriter.replaceOp(op, {finalOuput, mean, var});
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-08-10 13:12:34 +08:00
|
|
|
// AtenCatOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|
|
|
AtenCatOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
auto outType =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
2022-08-10 13:12:34 +08:00
|
|
|
int64_t dim;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
2022-08-10 13:12:34 +08:00
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only constant dim param is supported");
|
|
|
|
}
|
2023-04-07 19:49:35 +08:00
|
|
|
dim = toPositiveDim(dim, outType.getRank());
|
|
|
|
if (!isValidDim(dim, outType.getRank()))
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
2022-08-10 13:12:34 +08:00
|
|
|
|
|
|
|
SmallVector<Value> torchTensors;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!getListConstructElements(op.getTensors(), torchTensors)) {
|
2022-08-10 13:12:34 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "input should comes from a PrimListConstructOp");
|
|
|
|
}
|
|
|
|
SmallVector<Value> builtinTensors = getTypeConvertedValues(
|
|
|
|
rewriter, op->getLoc(), getTypeConverter(), torchTensors);
|
|
|
|
|
|
|
|
// Promote type
|
|
|
|
for (auto &v : builtinTensors) {
|
2023-06-26 00:04:17 +08:00
|
|
|
v = hlo::promoteType(rewriter, op->getLoc(), v, outType);
|
2022-08-10 13:12:34 +08:00
|
|
|
}
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
|
2023-04-07 19:49:35 +08:00
|
|
|
op, outType, ValueRange(builtinTensors), dim);
|
2022-08-23 16:47:21 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// AtenNumelOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
|
2022-09-16 15:09:21 +08:00
|
|
|
AtenNumelOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto self = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
2022-08-23 16:47:21 +08:00
|
|
|
size_t rank = selfTy.getRank();
|
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
|
2022-08-23 16:47:21 +08:00
|
|
|
auto loc = op->getLoc();
|
2022-09-16 15:09:21 +08:00
|
|
|
Value numel = rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getIntegerAttr(intType, 1));
|
|
|
|
for (size_t d = 0; d < rank; ++d) {
|
|
|
|
Value dimSize = rewriter.create<arith::IndexCastOp>(
|
2022-08-23 16:47:21 +08:00
|
|
|
loc, intType, rewriter.create<tensor::DimOp>(loc, self, d));
|
2022-09-16 15:09:21 +08:00
|
|
|
numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
|
2022-08-23 16:47:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
auto outTy = getTypeConverter()->convertType(op.getType());
|
|
|
|
if (outTy != numel.getType()) {
|
2022-09-16 15:09:21 +08:00
|
|
|
rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, outTy, numel);
|
2022-08-23 16:47:21 +08:00
|
|
|
} else {
|
|
|
|
rewriter.replaceOp(op, numel);
|
|
|
|
}
|
2022-08-10 13:12:34 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-09-16 15:09:21 +08:00
|
|
|
// AtenClampOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
|
|
|
AtenClampOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputType = cast<RankedTensorType>(input.getType());
|
2022-09-16 15:09:21 +08:00
|
|
|
auto inputElemType = inputType.getElementType();
|
2022-12-08 04:20:41 +08:00
|
|
|
Value minValue = adaptor.getMin();
|
|
|
|
Value maxValue = adaptor.getMax();
|
2022-09-16 15:09:21 +08:00
|
|
|
if (failed(checkNotNone(rewriter, op, minValue)) &&
|
|
|
|
failed(checkNotNone(rewriter, op, maxValue))) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "this op should be folded as its `min` and `max` both are none");
|
|
|
|
} else if (failed(checkNotNone(rewriter, op, minValue))) {
|
2023-02-02 21:29:47 +08:00
|
|
|
maxValue =
|
|
|
|
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
|
2022-09-16 15:09:21 +08:00
|
|
|
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
|
|
|
|
if (failed(minInfo)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to generate min value of dtype");
|
|
|
|
}
|
|
|
|
minValue = *minInfo;
|
|
|
|
} else if (failed(checkNotNone(rewriter, op, maxValue))) {
|
2023-02-02 21:29:47 +08:00
|
|
|
minValue =
|
|
|
|
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
|
2022-09-16 15:09:21 +08:00
|
|
|
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
|
|
|
|
if (failed(maxInfo)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to generate max value of dtype");
|
|
|
|
}
|
|
|
|
maxValue = *maxInfo;
|
|
|
|
} else {
|
2023-02-02 21:29:47 +08:00
|
|
|
minValue =
|
|
|
|
hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType);
|
|
|
|
maxValue =
|
|
|
|
hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType);
|
2022-09-16 15:09:21 +08:00
|
|
|
}
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
|
|
|
|
maxValue);
|
2022-09-16 15:09:21 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2024-04-19 10:55:27 +08:00
|
|
|
// AtenClampTensorOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenClampTensorOp>::matchAndRewrite(
|
|
|
|
AtenClampTensorOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Value input = adaptor.getSelf();
|
|
|
|
auto inputType = cast<RankedTensorType>(input.getType());
|
|
|
|
auto inputElemType = inputType.getElementType();
|
|
|
|
Value minValue = adaptor.getMin();
|
|
|
|
Value maxValue = adaptor.getMax();
|
|
|
|
auto minIsNotNone = checkNotNone(rewriter, op, minValue);
|
|
|
|
auto maxIsNotNone = checkNotNone(rewriter, op, maxValue);
|
|
|
|
if (failed(minIsNotNone) && failed(maxIsNotNone)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "this op should be folded as its `min` and `max` both are none");
|
|
|
|
} else if (failed(minIsNotNone)) {
|
|
|
|
auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter);
|
|
|
|
if (failed(minInfo)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to generate min value of dtype");
|
|
|
|
}
|
|
|
|
minValue = *minInfo;
|
|
|
|
} else if (failed(maxIsNotNone)) {
|
|
|
|
auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter);
|
|
|
|
if (failed(maxInfo)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to generate max value of dtype");
|
|
|
|
}
|
|
|
|
maxValue = *maxInfo;
|
|
|
|
}
|
2024-04-19 17:08:29 +08:00
|
|
|
if (inputType.hasStaticShape()) {
|
|
|
|
minValue = hlo::promoteAndBroadcast(rewriter, minValue, inputType);
|
|
|
|
maxValue = hlo::promoteAndBroadcast(rewriter, maxValue, inputType);
|
|
|
|
}
|
2024-04-19 10:55:27 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ClampOp>(op, minValue, input,
|
|
|
|
maxValue);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-09-20 22:31:24 +08:00
|
|
|
// AtenArangeStartStepOp
|
|
|
|
// aten.arange.start_step = range(ceil((end-start)/step)) * step + start.
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
|
|
|
AtenArangeStartStepOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
|
|
|
// Get element type of resultType as dtype
|
|
|
|
auto outType = this->getTypeConverter()
|
|
|
|
->convertType(op.getType())
|
|
|
|
.cast<RankedTensorType>();
|
|
|
|
auto dtype = outType.getElementType();
|
2024-04-11 21:47:35 +08:00
|
|
|
if (!isa<mlir::IntegerType>(dtype) && !isa<mlir::FloatType>(dtype)) {
|
2022-09-20 22:31:24 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only int or float dtype supported");
|
|
|
|
}
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
Value start =
|
|
|
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype);
|
|
|
|
Value end =
|
|
|
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype);
|
|
|
|
Value step =
|
|
|
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype);
|
2022-09-20 22:31:24 +08:00
|
|
|
|
|
|
|
// Get length of the 1-d output tensor
|
2023-02-02 21:29:47 +08:00
|
|
|
Value subOut = rewriter.create<stablehlo::SubtractOp>(loc, end, start);
|
2024-04-11 15:55:56 +08:00
|
|
|
// promote div to f64
|
|
|
|
Type divType = RankedTensorType::get({}, rewriter.getF64Type());
|
|
|
|
Value divOut = rewriter.create<stablehlo::DivOp>(
|
|
|
|
loc, rewriter.create<stablehlo::ConvertOp>(loc, divType, subOut),
|
|
|
|
rewriter.create<stablehlo::ConvertOp>(loc, divType, step));
|
|
|
|
// ceil to i64
|
|
|
|
Value resultLength = rewriter.create<stablehlo::ConvertOp>(
|
|
|
|
loc, RankedTensorType::get({}, rewriter.getI64Type()),
|
|
|
|
rewriter.create<stablehlo::CeilOp>(loc, divOut));
|
|
|
|
resultLength = rewriter.create<stablehlo::ReshapeOp>(
|
|
|
|
loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength);
|
2022-09-20 22:31:24 +08:00
|
|
|
|
|
|
|
Value window =
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.create<stablehlo::DynamicIotaOp>(loc, outType, resultLength, 0);
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
DenseI64ArrayAttr broadcastDimensions;
|
2022-09-20 22:31:24 +08:00
|
|
|
Value mulOut = rewriter.create<chlo::BroadcastMulOp>(loc, window, step,
|
|
|
|
broadcastDimensions);
|
|
|
|
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, mulOut, start,
|
|
|
|
broadcastDimensions);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2024-04-24 14:15:11 +08:00
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
|
|
|
|
AtenConstantPadNdOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Value self = adaptor.getSelf();
|
|
|
|
auto selfTy = self.getType().cast<RankedTensorType>();
|
|
|
|
auto selfElemTy = selfTy.getElementType();
|
|
|
|
int64_t rank = selfTy.getRank();
|
|
|
|
|
|
|
|
SmallVector<int64_t> padInts;
|
|
|
|
if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"only support constant int pad ranges");
|
|
|
|
uint64_t padRank = padInts.size() / 2;
|
|
|
|
if (padRank * 2 != padInts.size())
|
|
|
|
return rewriter.notifyMatchFailure(op, "pad range size is not even");
|
|
|
|
if (rank < 0 || padRank > (uint64_t)rank)
|
|
|
|
return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");
|
|
|
|
|
|
|
|
// Initialize low/high paddings with 0 for all the dims.
|
|
|
|
SmallVector<int64_t> lowPadding(/*Size=*/rank, /*Value=*/0);
|
|
|
|
SmallVector<int64_t> highPadding(/*Size=*/rank, /*Value=*/0);
|
|
|
|
// Add the requested padding - note op.pad() is highest dim first ordered
|
|
|
|
// pairs of low,high.
|
|
|
|
// Add the requested padding - note op.pad() is highest dim first ordered
|
|
|
|
// pairs of low,high.
|
|
|
|
for (uint64_t i = 0; i < padRank; ++i) {
|
|
|
|
lowPadding[rank - i - 1] = padInts[i * 2];
|
|
|
|
highPadding[rank - i - 1] = padInts[i * 2 + 1];
|
|
|
|
}
|
|
|
|
|
|
|
|
Value constantValue = hlo::scalarToStablehloTensor(
|
|
|
|
rewriter, op, adaptor.getValue(), selfElemTy);
|
|
|
|
|
|
|
|
SmallVector<int64_t> interiorPadding(rank, 0);
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::PadOp>(
|
|
|
|
op, self, constantValue, lowPadding, highPadding, interiorPadding);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-12-21 20:09:43 +08:00
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
|
|
|
AtenGeluBackwardOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
Value input = adaptor.getSelf();
|
2023-02-02 21:29:47 +08:00
|
|
|
auto outType =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
|
2022-12-21 20:09:43 +08:00
|
|
|
if (!outType) {
|
|
|
|
return op.emitError("only tensor type is supported");
|
|
|
|
}
|
|
|
|
// TODO: Handle approximate.
|
|
|
|
std::string approximate;
|
|
|
|
if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate)) ||
|
|
|
|
approximate != "none") {
|
|
|
|
return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
|
|
|
|
}
|
|
|
|
// Create constant value
|
2024-05-26 12:34:56 +08:00
|
|
|
Value kAlpha =
|
|
|
|
hlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input);
|
2022-12-21 20:09:43 +08:00
|
|
|
Value cstAlpha0 =
|
2024-05-26 12:34:56 +08:00
|
|
|
hlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input);
|
|
|
|
Value half = hlo::getConstantLike(rewriter, loc, .5, input);
|
|
|
|
Value one = hlo::getConstantLike(rewriter, loc, 1.0, input);
|
|
|
|
Value negHalf = hlo::getConstantLike(rewriter, loc, -0.5, input);
|
2022-12-21 20:09:43 +08:00
|
|
|
|
|
|
|
// Compute
|
2023-02-02 21:29:47 +08:00
|
|
|
Value kBeta0 =
|
|
|
|
rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha, cstAlpha0);
|
|
|
|
Value kBeta = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta0, half);
|
|
|
|
Value erfArg = rewriter.create<stablehlo::MulOp>(loc, outType, kAlpha,
|
|
|
|
adaptor.getSelf());
|
2022-12-21 20:09:43 +08:00
|
|
|
Value erf = rewriter.create<mlir::chlo::ErfOp>(loc, outType, erfArg);
|
2023-02-02 21:29:47 +08:00
|
|
|
Value erfAdd = rewriter.create<stablehlo::AddOp>(loc, outType, erf, one);
|
|
|
|
Value cdf = rewriter.create<stablehlo::MulOp>(loc, outType, erfAdd, half);
|
|
|
|
Value inputSquared = rewriter.create<stablehlo::MulOp>(
|
2022-12-21 20:09:43 +08:00
|
|
|
loc, outType, adaptor.getSelf(), adaptor.getSelf());
|
|
|
|
Value negHalfInputSquared =
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.create<stablehlo::MulOp>(loc, outType, inputSquared, negHalf);
|
2022-12-21 20:09:43 +08:00
|
|
|
Value expRes =
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.create<stablehlo::ExpOp>(loc, outType, negHalfInputSquared);
|
|
|
|
Value pdf = rewriter.create<stablehlo::MulOp>(loc, outType, kBeta, expRes);
|
2022-12-21 20:09:43 +08:00
|
|
|
Value pdfTimesInput =
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.create<stablehlo::MulOp>(loc, outType, pdf, adaptor.getSelf());
|
2022-12-21 20:09:43 +08:00
|
|
|
Value pdfTimesInputAddCdf =
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.create<stablehlo::AddOp>(loc, outType, pdfTimesInput, cdf);
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(
|
|
|
|
op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf);
|
2022-12-21 20:09:43 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2023-05-05 00:55:03 +08:00
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
|
|
|
|
AtenPowTensorTensorOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Value lhs = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto lhsTy = cast<TensorType>(lhs.getType());
|
2023-05-05 00:55:03 +08:00
|
|
|
Value rhs = adaptor.getExponent();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto rhsTy = cast<TensorType>(rhs.getType());
|
2023-05-05 00:55:03 +08:00
|
|
|
|
|
|
|
if (!lhsTy || !rhsTy)
|
|
|
|
return op.emitError("only Tensor types supported");
|
|
|
|
|
|
|
|
auto outTy =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
|
2023-05-05 00:55:03 +08:00
|
|
|
|
2023-06-26 00:04:17 +08:00
|
|
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
|
|
|
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
|
2023-05-05 00:55:03 +08:00
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outTy, lhs, rhs,
|
|
|
|
/*broadcast_attr*/ nullptr);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2023-05-19 10:07:35 +08:00
|
|
|
// Converts `aten.empty.memory_format` to `tensor.empty` op.
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
|
|
|
|
AtenEmptyMemoryFormatOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
|
|
|
// TODO: Add support pin_memory and memory_format features.
|
|
|
|
// At this point all tensors should have value semantics, and hence the
|
|
|
|
// `layout` check can be ignored.
|
|
|
|
|
|
|
|
// The pin_memory should be either `False` or `none`.
|
|
|
|
bool pinMemory;
|
2024-04-28 05:00:56 +08:00
|
|
|
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
|
2023-05-19 10:07:35 +08:00
|
|
|
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
|
|
|
pinMemory))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: pin_memory must be either None or false");
|
|
|
|
|
|
|
|
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
2024-04-28 05:00:56 +08:00
|
|
|
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
|
2023-05-19 10:07:35 +08:00
|
|
|
int64_t memoryFormat;
|
|
|
|
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: the memory format should be specified in "
|
|
|
|
"an integer constant");
|
|
|
|
if (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
|
|
|
memoryFormat != torch_upstream::MemoryFormat::Preserve)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: only none, contiguous and preserve "
|
|
|
|
"memory_format is supported");
|
|
|
|
}
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
if (!isa<Torch::NoneType>(op.getDevice().getType())) {
|
2023-05-19 10:07:35 +08:00
|
|
|
std::string device;
|
|
|
|
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: device must be a constant str");
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: Add support for non-strided layout.
|
|
|
|
// torch.layout is by default strided i.e. 0.
|
2024-04-28 05:00:56 +08:00
|
|
|
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
|
2023-05-19 10:07:35 +08:00
|
|
|
int64_t tensorLayout;
|
|
|
|
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: layout must be a constant");
|
|
|
|
else if (tensorLayout != torch_upstream::Layout::Strided)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: layout is expected to be strided");
|
|
|
|
}
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
2023-08-16 00:53:28 +08:00
|
|
|
const TypeConverter *typeConverter = this->getTypeConverter();
|
2023-05-19 10:07:35 +08:00
|
|
|
SmallVector<Value> resultSizeTorchInt, resultSize, resultSizeIndex;
|
|
|
|
if (!getListConstructElements(op.getSize(), resultSizeTorchInt)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: size must be constructed using ListConstruct");
|
|
|
|
}
|
|
|
|
resultSize =
|
|
|
|
getTypeConvertedValues(rewriter, loc, typeConverter, resultSizeTorchInt);
|
|
|
|
for (auto size : resultSize)
|
|
|
|
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
|
|
|
|
|
|
|
auto resultType =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
2023-05-19 10:07:35 +08:00
|
|
|
Type resultElementType;
|
2024-04-28 05:00:56 +08:00
|
|
|
if (isa<Torch::NoneType>(op.getDtype().getType())) {
|
2023-07-10 15:36:21 +08:00
|
|
|
resultElementType = resultType.getElementType();
|
2023-05-19 10:07:35 +08:00
|
|
|
} else {
|
|
|
|
int64_t dtypeInt;
|
|
|
|
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: dtype must be a constant integer or none");
|
2024-03-04 23:31:54 +08:00
|
|
|
FailureOr<Type> maybeResultElementType =
|
|
|
|
torch_to_stablehlo::getBackendTypeForScalarType(
|
|
|
|
op->getContext(), (torch_upstream::ScalarType)dtypeInt);
|
2023-05-19 10:07:35 +08:00
|
|
|
if (failed(maybeResultElementType)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unable to convert `dtypeInt` to builtin type");
|
|
|
|
}
|
|
|
|
resultElementType = *maybeResultElementType;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create an uninitialized tensor of `resultSize` shape.
|
|
|
|
Value initTensor = rewriter.create<tensor::EmptyOp>(
|
|
|
|
loc, getAsOpFoldResult(resultSizeIndex), resultElementType);
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, initTensor);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-12-22 17:12:52 +08:00
|
|
|
// RuntimeAssertOp
|
|
|
|
namespace {
|
|
|
|
class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
|
|
|
|
public:
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
bool condition;
|
|
|
|
if (!matchPattern(op.getCondition(), m_TorchConstantBool(&condition))) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unimplemented: condition must be a constant");
|
|
|
|
}
|
|
|
|
if (!condition) {
|
|
|
|
return op->emitError("condition must be true");
|
|
|
|
}
|
|
|
|
rewriter.eraseOp(op);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
2023-06-15 10:27:34 +08:00
|
|
|
// AtenFillScalarOp
|
2023-05-12 07:41:46 +08:00
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
|
|
|
|
AtenFillScalarOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
auto outType =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
2023-05-12 07:41:46 +08:00
|
|
|
auto dtype = outType.getElementType();
|
|
|
|
Value scalarTensor =
|
|
|
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype);
|
2023-07-27 18:35:25 +08:00
|
|
|
Value shapeTensor =
|
|
|
|
rewriter.create<shape::ShapeOfOp>(op->getLoc(), adaptor.getSelf());
|
|
|
|
Value bcastScalar = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
|
|
|
|
op->getLoc(), outType, scalarTensor, shapeTensor,
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
rewriter.getDenseI64ArrayAttr({}));
|
2023-05-12 07:41:46 +08:00
|
|
|
rewriter.replaceOp(op, bcastScalar);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2023-06-15 10:27:34 +08:00
|
|
|
// AtenFlipOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
|
|
|
|
AtenFlipOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Value self = adaptor.getSelf();
|
|
|
|
auto outType =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
2023-06-15 10:27:34 +08:00
|
|
|
|
|
|
|
SmallVector<int64_t> dims;
|
|
|
|
if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "dims must be a list of const int");
|
|
|
|
}
|
|
|
|
for (unsigned i = 0, e = dims.size(); i < e; i++) {
|
|
|
|
dims[i] = toPositiveDim(dims[i], outType.getRank());
|
|
|
|
if (!isValidDim(dims[i], outType.getRank())) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-08 15:13:42 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ReverseOp>(op, outType, self, dims);
|
2023-06-15 10:27:34 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2024-04-21 00:03:37 +08:00
|
|
|
// AtenRemainderTensorOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(
|
|
|
|
AtenRemainderTensorOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Value lhs = adaptor.getSelf();
|
|
|
|
Value rhs = adaptor.getOther();
|
|
|
|
|
|
|
|
auto resultType =
|
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
|
|
|
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
|
|
|
|
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::RemOp>(op, lhs, rhs);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2024-04-21 08:39:36 +08:00
|
|
|
// AtenFmodTensorOp
|
|
|
|
// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
|
|
|
|
AtenFmodTensorOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
Value lhs = adaptor.getSelf();
|
|
|
|
Value rhs = adaptor.getOther();
|
|
|
|
|
|
|
|
auto resultType =
|
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
|
|
|
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
|
|
|
|
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
|
|
|
|
|
|
|
|
stablehlo::MulOp mul;
|
|
|
|
auto div = rewriter.create<stablehlo::DivOp>(loc, lhs, rhs);
|
|
|
|
if (isa<mlir::FloatType>(resultType.getElementType())) {
|
|
|
|
// rounding mode is trunc
|
|
|
|
auto sign = rewriter.create<stablehlo::SignOp>(loc, div);
|
|
|
|
auto abs = rewriter.create<stablehlo::AbsOp>(loc, div);
|
|
|
|
auto floor = rewriter.create<stablehlo::FloorOp>(loc, abs);
|
|
|
|
auto trunc = rewriter.create<stablehlo::MulOp>(loc, sign, floor);
|
|
|
|
mul = rewriter.create<stablehlo::MulOp>(loc, trunc, rhs);
|
|
|
|
} else {
|
|
|
|
mul = rewriter.create<stablehlo::MulOp>(loc, div, rhs);
|
|
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::SubtractOp>(op, lhs, mul);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2024-04-26 09:20:49 +08:00
|
|
|
// AtenBitwiseLeftShiftTensorOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenBitwiseLeftShiftTensorOp>::matchAndRewrite(
|
|
|
|
AtenBitwiseLeftShiftTensorOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Value lhs = adaptor.getSelf();
|
|
|
|
Value rhs = adaptor.getOther();
|
|
|
|
|
|
|
|
auto resultType =
|
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
|
|
|
rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ShiftLeftOp>(op, lhs, rhs);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
// AtenBitwiseRightShiftTensorOp
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenBitwiseRightShiftTensorOp>::matchAndRewrite(
|
|
|
|
AtenBitwiseRightShiftTensorOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
Value lhs = adaptor.getSelf();
|
|
|
|
Value rhs = adaptor.getOther();
|
|
|
|
|
|
|
|
auto resultType =
|
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
|
|
|
rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType);
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ShiftRightArithmeticOp>(op, lhs, rhs);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2024-05-20 15:49:24 +08:00
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
|
|
|
|
AtenTrilOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
|
|
|
|
Location loc = op.getLoc();
|
|
|
|
|
|
|
|
Value self = adaptor.getSelf();
|
|
|
|
|
|
|
|
auto selfTy = self.getType().cast<RankedTensorType>();
|
|
|
|
if (!selfTy.hasStaticShape()) {
|
|
|
|
return op->emitError("dynamic shaped input is not supported");
|
|
|
|
}
|
|
|
|
|
|
|
|
ArrayRef<int64_t> selfShape = selfTy.getShape();
|
|
|
|
int64_t selfRank = selfTy.getRank();
|
|
|
|
auto iotaElementTy = mlir::IntegerType::get(op.getContext(), 64);
|
|
|
|
auto iotaTy = RankedTensorType::get(
|
|
|
|
{selfShape[selfRank - 2], selfShape[selfRank - 1]}, iotaElementTy);
|
|
|
|
Value colIdxTensor =
|
|
|
|
rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 1).getResult();
|
|
|
|
Value rowIdxTensor =
|
|
|
|
rewriter.create<stablehlo::IotaOp>(loc, iotaTy, 0).getResult();
|
|
|
|
|
|
|
|
Value diagonal = adaptor.getDiagonal();
|
|
|
|
Value diagonalTensor =
|
|
|
|
rewriter.create<tensor::FromElementsOp>(loc, diagonal).getResult();
|
|
|
|
|
|
|
|
auto bcastDimensions = rewriter.getDenseI64ArrayAttr({1});
|
|
|
|
Value shiftedRowIdxTensor = rewriter.create<chlo::BroadcastAddOp>(
|
|
|
|
loc, rowIdxTensor, diagonalTensor, bcastDimensions);
|
|
|
|
|
|
|
|
auto cmpDirectionAttr = stablehlo::ComparisonDirectionAttr::get(
|
|
|
|
rewriter.getContext(), stablehlo::ComparisonDirection::LE);
|
|
|
|
auto cmpTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
|
|
|
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
|
|
|
auto cmpTy = iotaTy.clone(rewriter.getI1Type());
|
|
|
|
Value cmpRes = rewriter.create<stablehlo::CompareOp>(
|
|
|
|
loc, cmpTy, colIdxTensor, shiftedRowIdxTensor, cmpDirectionAttr,
|
|
|
|
cmpTypeAttr);
|
|
|
|
|
|
|
|
auto resTy =
|
|
|
|
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
|
|
|
|
|
|
|
auto bcastTy = resTy.clone(rewriter.getI1Type());
|
|
|
|
auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1});
|
|
|
|
Value bcastedCmpRes = rewriter.create<stablehlo::BroadcastInDimOp>(
|
|
|
|
loc, bcastTy, cmpRes, bcastAttr);
|
|
|
|
|
|
|
|
auto resElemTy = resTy.getElementType();
|
|
|
|
Value zeroTensor;
|
|
|
|
if (resElemTy.isa<mlir::FloatType>()) {
|
|
|
|
auto constAttr = SplatElementsAttr::get(
|
|
|
|
resTy, llvm::APFloat::getZero(
|
|
|
|
resElemTy.cast<FloatType>().getFloatSemantics(), false));
|
|
|
|
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
|
|
|
|
} else if (resElemTy.isa<mlir::IntegerType>()) {
|
|
|
|
auto constAttr = SplatElementsAttr::get(
|
|
|
|
resTy,
|
|
|
|
llvm::APInt::getZero(resElemTy.cast<mlir::IntegerType>().getWidth()));
|
|
|
|
zeroTensor = rewriter.create<stablehlo::ConstantOp>(loc, resTy, constAttr);
|
|
|
|
} else {
|
|
|
|
return op.emitError("element type is not float or integer");
|
|
|
|
}
|
|
|
|
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::SelectOp>(
|
|
|
|
op.getOperation(), resTy, bcastedCmpRes, self, zeroTensor);
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
2022-07-21 07:18:16 +08:00
|
|
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
2023-02-02 21:29:47 +08:00
|
|
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
2022-07-21 07:18:16 +08:00
|
|
|
MLIRContext *context = patterns.getContext();
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
target.addIllegalOp<AtenTransposeIntOp>();
|
|
|
|
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
|
2022-12-22 17:12:52 +08:00
|
|
|
target.addIllegalOp<RuntimeAssertOp>();
|
|
|
|
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \
|
2023-01-12 06:40:03 +08:00
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2023-02-02 21:29:47 +08:00
|
|
|
patterns.add<ConvertAtenUnaryOp<AtenOp, StablehloOp>>(typeConverter, context)
|
|
|
|
INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp);
|
|
|
|
INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp);
|
|
|
|
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp);
|
2023-05-09 13:13:00 +08:00
|
|
|
INSERT_UNARY_PATTERN(AtenAbsOp, stablehlo::AbsOp);
|
2024-04-22 10:20:49 +08:00
|
|
|
INSERT_UNARY_PATTERN(AtenExpm1Op, stablehlo::Expm1Op);
|
2023-01-12 06:40:03 +08:00
|
|
|
#undef INSERT_UNARY_PATTERN
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \
|
2022-07-27 13:07:51 +08:00
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2023-02-02 21:29:47 +08:00
|
|
|
patterns.add<ConvertAtenUnaryFPOnlyOp<AtenOp, StablehloOp>>(typeConverter, \
|
|
|
|
context)
|
2023-02-07 03:14:26 +08:00
|
|
|
INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp);
|
|
|
|
INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp);
|
2024-03-12 08:58:20 +08:00
|
|
|
INSERT_UNARY_FPONLY_PATTERN(AtenRoundOp, stablehlo::RoundNearestEvenOp);
|
2022-07-27 13:07:51 +08:00
|
|
|
#undef INSERT_UNARY_FPONLY_PATTERN
|
2024-04-23 17:57:12 +08:00
|
|
|
|
|
|
|
#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, StablehloOp) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
|
|
|
patterns.add<ConvertAtenUnaryPromoteToFPOp<AtenOp, StablehloOp>>( \
|
|
|
|
typeConverter, context)
|
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, stablehlo::LogOp);
|
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLog1pOp, stablehlo::Log1pOp);
|
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, stablehlo::ExpOp);
|
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSqrtOp, stablehlo::SqrtOp);
|
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp);
|
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp);
|
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanhOp, stablehlo::TanhOp);
|
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinOp, stablehlo::SineOp);
|
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCosOp, stablehlo::CosineOp);
|
2024-04-26 15:47:44 +08:00
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenTanOp, chlo::TanOp);
|
2024-04-23 17:57:12 +08:00
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinOp, chlo::AsinOp);
|
2024-04-23 19:54:58 +08:00
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenSinhOp, chlo::SinhOp);
|
2024-04-23 17:57:12 +08:00
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcosOp, chlo::AcosOp);
|
2024-04-23 19:54:58 +08:00
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenCoshOp, chlo::CoshOp);
|
2024-04-23 17:57:12 +08:00
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanOp, chlo::AtanOp);
|
2024-04-26 15:47:44 +08:00
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAsinhOp, chlo::AsinhOp);
|
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAcoshOp, chlo::AcoshOp);
|
|
|
|
INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenAtanhOp, chlo::AtanhOp);
|
2024-04-23 17:57:12 +08:00
|
|
|
#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
|
|
|
patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
|
2022-09-01 10:36:02 +08:00
|
|
|
context)
|
2022-07-27 13:07:51 +08:00
|
|
|
INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1);
|
|
|
|
INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0);
|
|
|
|
#undef INSERT_CONSTANT_FILL_PATTERN
|
|
|
|
|
2023-07-29 21:55:49 +08:00
|
|
|
#define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2024-01-30 01:59:33 +08:00
|
|
|
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenOp>>(typeConverter, context)
|
2023-07-29 21:55:49 +08:00
|
|
|
|
|
|
|
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp);
|
|
|
|
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp);
|
|
|
|
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenBoolTensorOp);
|
|
|
|
#undef INSERT_TENSOR_TO_SCALAR_PATTERN
|
|
|
|
|
2022-08-02 12:53:24 +08:00
|
|
|
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \
|
2022-07-27 13:07:51 +08:00
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2022-09-01 10:36:02 +08:00
|
|
|
patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context)
|
2022-08-02 12:53:24 +08:00
|
|
|
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp);
|
|
|
|
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
|
|
|
|
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
|
|
|
|
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp);
|
2022-08-17 09:07:36 +08:00
|
|
|
INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, chlo::BroadcastSubOp);
|
2022-07-27 13:07:51 +08:00
|
|
|
#undef INSERT_BINARY_ADDSUB_PATTERN
|
|
|
|
|
2022-08-02 12:53:24 +08:00
|
|
|
#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
|
2022-07-27 13:07:51 +08:00
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2022-09-01 10:36:02 +08:00
|
|
|
patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context)
|
2022-08-02 12:53:24 +08:00
|
|
|
INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
|
|
|
|
INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
|
|
|
|
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
|
2022-08-06 23:38:06 +08:00
|
|
|
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
|
2022-08-02 12:53:24 +08:00
|
|
|
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
|
2024-04-16 04:45:10 +08:00
|
|
|
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarModeOp, chlo::BroadcastDivOp);
|
2022-11-24 14:28:34 +08:00
|
|
|
INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp);
|
2022-08-02 12:53:24 +08:00
|
|
|
#undef INSERT_BINARY_MULDIV_PATTERN
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2022-09-01 10:36:02 +08:00
|
|
|
patterns.add<ConvertAtenCompareOp<AtenOp>>(typeConverter, context)
|
2022-08-02 12:53:24 +08:00
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
|
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
|
2022-11-24 14:28:34 +08:00
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp);
|
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp);
|
2022-07-27 13:07:51 +08:00
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
|
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp);
|
2022-11-24 14:28:34 +08:00
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp);
|
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp);
|
2022-07-27 13:07:51 +08:00
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp);
|
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp);
|
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp);
|
|
|
|
INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp);
|
|
|
|
#undef INSERT_BINARY_COMPARE_PATTERN
|
|
|
|
|
2023-01-04 10:11:25 +08:00
|
|
|
#define INSERT_BINARY_LOGICAL_PATTERN(AtenOp, ChloOp) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
|
|
|
patterns.add<ConvertAtenLogicalBinaryOp<AtenOp, ChloOp>>(typeConverter, \
|
|
|
|
context)
|
|
|
|
|
|
|
|
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalOrOp, chlo::BroadcastOrOp);
|
|
|
|
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalAndOp, chlo::BroadcastAndOp);
|
|
|
|
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalXorOp, chlo::BroadcastXorOp);
|
2024-04-08 20:24:17 +08:00
|
|
|
INSERT_BINARY_LOGICAL_PATTERN(AtenBitwiseAndScalarOp, chlo::BroadcastAndOp);
|
|
|
|
|
2023-01-04 10:11:25 +08:00
|
|
|
#undef INSERT_BINARY_LOGICAL_PATTERN
|
|
|
|
|
2022-07-21 07:18:16 +08:00
|
|
|
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2022-09-01 10:36:02 +08:00
|
|
|
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
|
|
|
|
|
|
|
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
|
|
|
INSERT_ATENOP_PATTERN(AtenPermuteOp);
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
|
2023-07-20 16:46:44 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenTensorIntOp);
|
2022-07-27 13:07:51 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
|
2023-02-17 12:26:46 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
|
2024-04-26 15:47:44 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenPowScalarOp);
|
2022-07-27 13:07:51 +08:00
|
|
|
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
|
2023-09-04 14:04:09 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
|
2022-07-27 13:07:51 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenContiguousOp);
|
2024-04-24 14:15:11 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
2022-07-21 07:18:16 +08:00
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenReluOp);
|
2022-08-02 15:01:30 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenGeluOp);
|
2024-04-23 19:06:55 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenLog2Op);
|
|
|
|
INSERT_ATENOP_PATTERN(AtenLog10Op);
|
2022-07-27 13:07:51 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenErfOp);
|
2022-12-21 20:09:43 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2022-08-10 13:12:34 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenCatOp);
|
2022-09-16 15:09:21 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenClampOp);
|
2024-04-19 10:55:27 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
|
2022-09-20 22:31:24 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
|
2022-08-10 13:12:34 +08:00
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
|
|
|
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
2022-08-23 16:47:21 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenNumelOp);
|
2022-09-23 10:24:36 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenSizeIntOp);
|
|
|
|
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
2022-11-24 14:28:34 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
|
2023-05-05 00:55:03 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
|
2024-05-11 15:33:37 +08:00
|
|
|
|
2023-05-19 10:07:35 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
|
2023-05-12 07:41:46 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
|
2023-06-15 10:27:34 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenFlipOp);
|
2024-04-21 00:03:37 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
|
2024-04-21 08:39:36 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
|
2024-04-26 09:20:49 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
|
|
|
|
INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);
|
2024-05-20 15:49:24 +08:00
|
|
|
|
|
|
|
INSERT_ATENOP_PATTERN(AtenTrilOp);
|
2022-07-27 13:07:51 +08:00
|
|
|
#undef INSERT_ATENOP_PATTERN
|
2022-09-23 20:39:15 +08:00
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \
|
2022-09-23 20:39:15 +08:00
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2023-02-02 21:29:47 +08:00
|
|
|
patterns.add<ConvertAtenBinaryBroadcastOp<AtenOp, StablehloOp>>( \
|
|
|
|
typeConverter, context)
|
2022-09-23 20:39:15 +08:00
|
|
|
INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp);
|
|
|
|
INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp);
|
|
|
|
INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp);
|
2022-11-24 14:28:34 +08:00
|
|
|
INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseAndTensorOp, chlo::BroadcastAndOp);
|
2023-01-12 06:40:03 +08:00
|
|
|
INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseOrTensorOp, chlo::BroadcastOrOp);
|
|
|
|
INSERT_BINARY_BROADCAST_PATTERN(AtenBitwiseXorTensorOp, chlo::BroadcastXorOp);
|
2024-04-26 15:47:44 +08:00
|
|
|
INSERT_BINARY_BROADCAST_PATTERN(AtenAtan2Op, chlo::BroadcastAtan2Op);
|
2022-09-23 20:39:15 +08:00
|
|
|
#undef INSERT_BINARY_BROADCAST_PATTERN
|
2022-08-03 08:16:31 +08:00
|
|
|
}
|