//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

#include <numeric>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;
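
// Returns a zero-rank constant holding the identity element of the pooling
// reduction: 0 for the sum-based ops (average pooling and cumsum), and -inf
// (floats) or the smallest signed value (integers) for max pooling, so that
// padded positions can never win the reduction.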
static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
                                                PatternRewriter &rewriter) {
  auto constType = RankedTensorType::get({}, elementTy);
  // Avg pooling
  if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
          AtenAvgPool3dOp, AtenCumsumOp>(op)) {
    if (isa<mlir::FloatType>(elementTy)) {
      auto constAttr = DenseElementsAttr::get(
          constType, {APFloat::getZero(
                         cast<mlir::FloatType>(elementTy).getFloatSemantics(),
                         /*negative=*/false)});
      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                    constAttr);
    } else if (isa<mlir::IntegerType>(elementTy) &&
               elementTy.getIntOrFloatBitWidth() != 8) {
      auto constAttr = DenseElementsAttr::get(
          constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                    constAttr);
    }
  }

  // Max pooling
  if (isa<AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp,
          AtenMaxPool2dWithIndicesOp>(op)) {
    if (isa<mlir::FloatType>(elementTy)) {
      auto constAttr = DenseElementsAttr::get(
          constType,
          {APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
                           /*negative=*/true)});
      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                    constAttr);
    } else if (isa<mlir::IntegerType>(elementTy) &&
               elementTy.getIntOrFloatBitWidth() != 8) {
      auto constAttr = DenseElementsAttr::get(
          constType,
          {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                    constAttr);
    }
  }
  op->emitError("unimplemented lowering in AtenPoolingOp");
  return nullptr;
}

// AtenMaxPool2dWithIndicesOp
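// Lowered to a variadic stablehlo.reduce_window over a (value, index) pair:
// a second operand carrying each element's linear index within its HxW plane
// is reduced alongside the values by an argmax-style region, so a single
// reduce_window yields both the pooled values and the indices result.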
template <>
LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
    AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.getSelf();
  auto inputTy = cast<RankedTensorType>(input.getType());
  auto inputElemTy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();
  auto inputRank = inputTy.getRank();
  auto outValTy =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
  auto outIdxTy =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));

  if (inputRank <= 2) {
    return op.emitError(
        "max_pooling2d only supports inputs with rank higher than 2");
  }
  SmallVector<int64_t, 2> padding, kernelSize, stride, dilation;
  bool ceilMode = false;

  if (!(matchPattern(op.getKernelSize(),
                     m_TorchListOfConstantInts(kernelSize)))) {
    return rewriter.notifyMatchFailure(
        op, "non-const int kernel size unsupported!");
  }
  if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
    return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
  }
  if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int padding unsupported!");
  }
  if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const int dilation unsupported!");
  }
  if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
    return rewriter.notifyMatchFailure(op,
                                       "non-const bool ceil_mode unsupported!");
  }

  // prepend 1 to kernelSize, stride, dilation until they are of same rank as
  // input
  SmallVector<int64_t> stablehloStride(inputRank, 1);
  SmallVector<int64_t> stablehloDilation(inputRank, 1);
  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
  std::copy(dilation.begin(), dilation.end(),
            stablehloDilation.begin() + inputRank - 2);
  std::copy(stride.begin(), stride.end(),
            stablehloStride.begin() + inputRank - 2);
  std::copy(kernelSize.begin(), kernelSize.end(),
            stablehloKernelSize.begin() + inputRank - 2);

  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

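  // stablehloPadding is a flat list of (low, high) pairs per dimension that is
  // later reshaped into a [rank, 2] attribute; torch's symmetric padding is
  // applied to both edges of the two trailing (spatial) dimensions.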
  stablehloPadding[stablehloPadding.size() - 4] = padding[0];
  stablehloPadding[stablehloPadding.size() - 3] = padding[0];
  stablehloPadding[stablehloPadding.size() - 2] = padding[1];
  stablehloPadding[stablehloPadding.size() - 1] = padding[1];

  auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
  auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
  DenseI64ArrayAttr baseDilations;
  auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
          rewriter.getI64Type()),
      stablehloPadding);

  const auto &options = getOptions();
  auto inputShapeInfo =
      hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
  if (failed(inputShapeInfo)) {
    return rewriter.notifyMatchFailure(
        op, "failed to get dimension sizes of the input");
  }
  auto inputShapeVec = *inputShapeInfo;
  auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
      op->getLoc(), inputShapeVec);

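  // Materialize the index operand: an iota of length H*W per spatial plane,
  // dynamically reshaped back to the input shape, so every element carries its
  // linear index (h * W + w) into the pooling comparison below.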
  SmallVector<Value> initIndexShapeVec;
  for (int64_t i = 0; i < inputRank - 2; i++)
    initIndexShapeVec.push_back(inputShapeVec[i]);
  initIndexShapeVec.push_back(rewriter.create<mlir::arith::MulIOp>(
      op->getLoc(), inputShapeVec[inputRank - 1],
      inputShapeVec[inputRank - 2]));
  auto initIndexShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
      op->getLoc(), initIndexShapeVec);

  SmallVector<int64_t> initIndexShapeForType(inputShape.begin(),
                                             inputShape.end() - 2);
  if (inputShape[inputRank - 1] == ShapedType::kDynamic ||
      inputShape[inputRank - 2] == ShapedType::kDynamic) {
    initIndexShapeForType.push_back(ShapedType::kDynamic);
  } else {
    initIndexShapeForType.push_back(inputShape[inputRank - 1] *
                                    inputShape[inputRank - 2]);
  }

  auto initIndexTensor =
      rewriter
          .create<stablehlo::DynamicIotaOp>(
              op->getLoc(),
              RankedTensorType::get(initIndexShapeForType,
                                    rewriter.getI64Type()),
              initIndexShapeTensor, static_cast<uint64_t>(inputRank - 2))
          .getResult();

  auto indexTensor =
      rewriter
          .create<stablehlo::DynamicReshapeOp>(
              op->getLoc(),
              RankedTensorType::get(inputShape, rewriter.getI64Type()),
              initIndexTensor, inputShapeTensor)
          .getResult();

  Value initIdx = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();

  auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
      op->getLoc(), mlir::TypeRange{outValTy, outIdxTy},
      mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx},
      windowDimensions, windowStrides, baseDilations, windowDilations, pad);

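  // The reduction region receives two (value, index) pairs and keeps the pair
  // with the larger value, breaking ties in favor of the smaller index so the
  // earliest maximal element within the window wins.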
  Block &block = reduceWindowOp.getBody().emplaceBlock();

  // Add bb argument
  auto blockValArgumentType = RankedTensorType::get({}, inputElemTy);
  auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
  auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
  block.addArgument(blockValArgumentType, op->getLoc());
  block.addArgument(blockIdxArgumentType, op->getLoc());
  block.addArgument(blockValArgumentType, op->getLoc());
  block.addArgument(blockIdxArgumentType, op->getLoc());
  auto *firstValArg = block.args_begin();
  auto *firstIdxArg = std::next(firstValArg);
  auto *secondValArg = std::next(firstIdxArg);
  auto *secondIdxArg = std::next(secondValArg);

  stablehlo::ComparisonTypeAttr compareTypeAttr;
  if (inputTy.getElementType().isa<mlir::FloatType>()) {
    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
        rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
  } else if (inputTy.getElementType().isa<mlir::IntegerType>()) {
    compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
        rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
  }
  stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
      stablehlo::ComparisonDirectionAttr::get(
          rewriter.getContext(), stablehlo::ComparisonDirection::GE);
  stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
      stablehlo::ComparisonDirectionAttr::get(
          rewriter.getContext(), stablehlo::ComparisonDirection::EQ);

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);

    Value compareGeResult = rewriter.create<stablehlo::CompareOp>(
        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
        compareGeDirectionAttr, compareTypeAttr);
    Value retValResult = rewriter.create<stablehlo::SelectOp>(
        op->getLoc(), compareGeResult, *firstValArg, *secondValArg);

    // Get smaller index if compared values are equal.
    Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
        op->getLoc(), compareResultType, *firstValArg, *secondValArg,
        compareEqDirectionAttr, compareTypeAttr);
    Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
                                                     *secondIdxArg);
    Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
        op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg);
    Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
        op->getLoc(), compareEqResult, minIdx, idxWithGeVal);

    rewriter.create<stablehlo::ReturnOp>(
        op->getLoc(), mlir::ValueRange{retValResult, retIdxResult});
  }

  rewriter.replaceOp(op, reduceWindowOp.getResults());
  return success();
}

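// AtenMaxPool1dOp, AtenMaxPool2dOp, AtenMaxPool3dOp
// Plain max pooling (no indices) is lowered to a single
// stablehlo.reduce_window whose region computes stablehlo.max over the window.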
namespace {
template <typename AtenOpT, int Dim>
class ConvertAtenMaxPoolOp : public ConvertAtenOp<AtenOpT> {
public:
  using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getSelf();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto inputElemTy = inputTy.getElementType();
    auto inputRank = inputTy.getRank();
    auto outTy = cast<RankedTensorType>(
        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType()));

    if (inputRank <= Dim) {
      return op.emitError("max_pooling1d/2d/3d only supports inputs with rank "
                          "higher than 1/2/3");
    }
    SmallVector<int64_t, Dim> padding, kernelSize, stride, dilation;
    bool ceilMode = false;

    if (!(matchPattern(op.getKernelSize(),
                       m_TorchListOfConstantInts(kernelSize)))) {
      return rewriter.notifyMatchFailure(
          op, "non-const int kernel size unsupported!");
    }
    if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
      return rewriter.notifyMatchFailure(op,
                                         "non-const int stride unsupported!");
    }
    if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
      return rewriter.notifyMatchFailure(op,
                                         "non-const int padding unsupported!");
    }
    if (!(matchPattern(op.getDilation(),
                       m_TorchListOfConstantInts(dilation)))) {
      return rewriter.notifyMatchFailure(op,
                                         "non-const int dilation unsupported!");
    }
    if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
      return rewriter.notifyMatchFailure(
          op, "non-const bool ceil_mode unsupported!");
    }

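    // An omitted stride (empty list) defaults to the kernel size, matching
    // torch.nn.MaxPool semantics.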
    if (stride.empty()) {
      stride = kernelSize;
    }

    // prepend 1 to kernelSize, stride, dilation until they are of same rank
    // as input
    SmallVector<int64_t> stablehloStride(inputRank, 1);
    SmallVector<int64_t> stablehloDilation(inputRank, 1);
    SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
    SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
    std::copy(dilation.begin(), dilation.end(),
              stablehloDilation.begin() + inputRank - Dim);
    std::copy(stride.begin(), stride.end(),
              stablehloStride.begin() + inputRank - Dim);
    std::copy(kernelSize.begin(), kernelSize.end(),
              stablehloKernelSize.begin() + inputRank - Dim);

    Value initVal =
        createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

    if (Dim == 1) {
      stablehloPadding[stablehloPadding.size() - 2] = padding[0];
      stablehloPadding[stablehloPadding.size() - 1] = padding[0];
    } else if (Dim == 2) {
      stablehloPadding[stablehloPadding.size() - 4] = padding[0];
      stablehloPadding[stablehloPadding.size() - 3] = padding[0];
      stablehloPadding[stablehloPadding.size() - 2] = padding[1];
      stablehloPadding[stablehloPadding.size() - 1] = padding[1];
    } else if (Dim == 3) {
      stablehloPadding[stablehloPadding.size() - 6] = padding[0];
      stablehloPadding[stablehloPadding.size() - 5] = padding[0];
      stablehloPadding[stablehloPadding.size() - 4] = padding[1];
      stablehloPadding[stablehloPadding.size() - 3] = padding[1];
      stablehloPadding[stablehloPadding.size() - 2] = padding[2];
      stablehloPadding[stablehloPadding.size() - 1] = padding[2];
    } else {
      assert(false && "Unsupported pooling dimension");
    }
    auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
    auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
    DenseI64ArrayAttr baseDilations;
    auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);

    DenseIntElementsAttr pad = DenseIntElementsAttr::get(
        RankedTensorType::get(
            {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
            rewriter.getI64Type()),
        stablehloPadding);

    auto reduceWindowOp = rewriter.create<stablehlo::ReduceWindowOp>(
        op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
        baseDilations, windowDilations, pad);

    Block &block = reduceWindowOp.getBody().emplaceBlock();

    // Add bb argument
    auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
    block.addArgument(blockArgumentType, op->getLoc());
    block.addArgument(blockArgumentType, op->getLoc());
    auto *firstArg = block.args_begin();
    auto secondArg = block.args_rbegin();

    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(&block);

      Value result = rewriter.create<stablehlo::MaxOp>(op->getLoc(), *firstArg,
                                                       *secondArg);
      rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result);
    }

    rewriter.replaceOp(op, reduceWindowOp.getResults());
    return success();
  }
};
} // namespace

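// AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp
// Average pooling is lowered as a summing stablehlo.reduce_window followed by
// a division: when count_include_pad is true the divisor is the constant
// kernel volume, otherwise a second reduce_window over a broadcast ones tensor
// counts the non-padding elements of each window.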
namespace {
template <typename AtenOpT, int Dim>
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
public:
  using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getSelf();
    RankedTensorType inputTy = cast<RankedTensorType>(input.getType());
    Type inputElemTy = inputTy.getElementType();
    int64_t inputRank = inputTy.getRank();
    RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
                                 ->convertType(op.getType())
                                 .template cast<RankedTensorType>();
    auto outShape = outTy.getShape();

    if (inputRank <= Dim) {
      return op.emitError("avg_pooling1d/2d/3d only supports inputs with rank "
                          "higher than 1/2/3");
    }
    SmallVector<int64_t, Dim> padding, kernelSize, stride;
    bool ceilMode = false;
    bool countIncludePad = true;

    if (!(matchPattern(op.getKernelSize(),
                       m_TorchListOfConstantInts(kernelSize)))) {
      return rewriter.notifyMatchFailure(
          op, "non-const int kernel size unsupported!");
    }
    if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
      return rewriter.notifyMatchFailure(op,
                                         "non-const int stride unsupported!");
    }
    if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
      return rewriter.notifyMatchFailure(op,
                                         "non-const int padding unsupported!");
    }
    if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
      return rewriter.notifyMatchFailure(
          op, "non-const bool ceil_mode unsupported!");
    }
    if (!(matchPattern(op.getCountIncludePad(),
                       m_TorchConstantBool(&countIncludePad)))) {
      return rewriter.notifyMatchFailure(
          op, "non-const bool count_include_pad unsupported!");
    }

    if (stride.empty()) {
      stride = kernelSize;
    }

    if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
      if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
        return rewriter.notifyMatchFailure(
            op, "only None divisor_override supported for now!");
    }

    // Prepend 1 to kernelSize, stride, dilation until they are of same rank
    // as input
    SmallVector<int64_t> stablehloStride(inputRank, 1);
    SmallVector<int64_t> stablehloDilation(inputRank, 1);
    SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
    SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);

    std::copy(stride.begin(), stride.end(),
              stablehloStride.begin() + inputRank - Dim);
    std::copy(kernelSize.begin(), kernelSize.end(),
              stablehloKernelSize.begin() + inputRank - Dim);
    if (Dim == 1) {
      stablehloPadding[stablehloPadding.size() - 2] = padding[0];
      stablehloPadding[stablehloPadding.size() - 1] = padding[0];
    } else if (Dim == 2) {
      stablehloPadding[stablehloPadding.size() - 4] = padding[0];
      stablehloPadding[stablehloPadding.size() - 3] = padding[0];
      stablehloPadding[stablehloPadding.size() - 2] = padding[1];
      stablehloPadding[stablehloPadding.size() - 1] = padding[1];
    } else if (Dim == 3) {
      stablehloPadding[stablehloPadding.size() - 6] = padding[0];
      stablehloPadding[stablehloPadding.size() - 5] = padding[0];
      stablehloPadding[stablehloPadding.size() - 4] = padding[1];
      stablehloPadding[stablehloPadding.size() - 3] = padding[1];
      stablehloPadding[stablehloPadding.size() - 2] = padding[2];
      stablehloPadding[stablehloPadding.size() - 1] = padding[2];
    } else {
      assert(false && "Unsupported pooling dimension");
    }

    Value initVal =
        createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

    auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
    auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
    DenseI64ArrayAttr baseDilations;
    auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
    DenseIntElementsAttr pad = DenseIntElementsAttr::get(
        RankedTensorType::get(
            {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
            rewriter.getI64Type()),
        stablehloPadding);

    auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
        op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
        baseDilations, windowDilations, pad);

    Block &sumBlock = reduceWindowSum.getBody().emplaceBlock();

    // Add bb argument
    auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
    sumBlock.addArgument(blockArgumentType, op->getLoc());
    sumBlock.addArgument(blockArgumentType, op->getLoc());
    auto firstArg = *sumBlock.args_begin();
    auto secondArg = *sumBlock.args_rbegin();

    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(&sumBlock);

      Value sumResult =
          rewriter.create<stablehlo::AddOp>(op->getLoc(), firstArg, secondArg);
      rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
    }

    // Use kernel size as the divisor
    if (countIncludePad) {
      Value divisor;
      if (Dim == 1) {
        divisor =
            hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
                .value();
      } else if (Dim == 2) {
        divisor = hlo::getConstTensor<int64_t>(
                      rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
                      .value();
      } else if (Dim == 3) {
        divisor = hlo::getConstTensor<int64_t>(
                      rewriter, op,
                      {kernelSize[0] * kernelSize[1] * kernelSize[2]}, {})
                      .value();
      } else {
        assert(false && "Unsupported pooling dimension");
      }
      divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
      DenseI64ArrayAttr bcastDimensions;
      rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
          op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
      return success();
    }

    // Use another stablehlo.ReduceWindowOp to get the divisor: summing a
    // broadcast ones tensor over the same windows yields, per output position,
    // the number of non-padding input elements in its window.
    Value windowSizeConst =
        hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
    windowSizeConst =
        hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
    const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
    auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input,
                                                   options.dimSizeIndexBits);
    auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
        op->getLoc(), inputShapeVec);

    windowSizeConst = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
        op->getLoc(),
        RankedTensorType::get(inputTy.getShape(), outTy.getElementType()),
        windowSizeConst, inputShapeTensor, rewriter.getDenseI64ArrayAttr({}));

    Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
    auto reduceWindowSize = rewriter.create<stablehlo::ReduceWindowOp>(
        op->getLoc(), RankedTensorType::get(outShape, inputElemTy),
        windowSizeConst, zero, windowDimensions, windowStrides, baseDilations,
        windowDilations, pad);

    Block &sizeBlock = reduceWindowSize.getBody().emplaceBlock();

    // Add bb argument
    blockArgumentType = RankedTensorType::get({}, inputElemTy);
    sizeBlock.addArgument(blockArgumentType, op->getLoc());
    sizeBlock.addArgument(blockArgumentType, op->getLoc());
    firstArg = *sizeBlock.args_begin();
    secondArg = *sizeBlock.args_rbegin();

    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(&sizeBlock);

      Value sumResult =
          rewriter.create<stablehlo::AddOp>(op->getLoc(), firstArg, secondArg);
      rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
    }

    rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
        op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
    return success();
  }
};
} // namespace

// AtenCumsumOp
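// Cumsum along a static dimension is expressed as a summing
// stablehlo.reduce_window whose window spans the whole dimension and whose low
// padding is size - 1, so each output position sums exactly the prefix ending
// at it. For example, cumsum of a [2, 7, 4] input along dim 1 uses
// window_dimensions = [1, 7, 1] and padding = [[0, 0], [6, 0], [0, 0]].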
template <>
LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
    AtenCumsumOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.getSelf();
  auto inputTy = cast<RankedTensorType>(input.getType());
  auto outTy =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
  input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
  inputTy = cast<RankedTensorType>(input.getType());
  auto inputElemTy = inputTy.getElementType();
  auto inputRank = inputTy.getRank();
  auto inputShape = inputTy.getShape();

  int64_t dim;
  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: dim must be a constant int");
  }
  dim = toPositiveDim(dim, inputRank);
  if (!isValidDim(dim, inputRank)) {
    return rewriter.notifyMatchFailure(op, "dim is out of range");
  }
  if (inputTy.isDynamicDim(dim)) {
    return rewriter.notifyMatchFailure(
        op, "unimplemented: cumsum dim must be static");
  }

  Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

  SmallVector<int64_t> stablehloKernelSize(inputRank, 1);
  stablehloKernelSize[dim] = inputShape[dim];
  SmallVector<int64_t> stablehloStride(inputRank, 1);
  SmallVector<int64_t> stablehloDilation(inputRank, 1);
  SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);
  stablehloPadding[dim * 2] = inputShape[dim] - 1;

  auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
  auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride);
  DenseI64ArrayAttr baseDilations;
  auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation);
  DenseIntElementsAttr pad = DenseIntElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int64_t>(inputRank), static_cast<int64_t>(2)},
          rewriter.getI64Type()),
      stablehloPadding);

  auto reduceWindowSum = rewriter.create<stablehlo::ReduceWindowOp>(
      op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides,
      baseDilations, windowDilations, pad);

  Block &sumBlock = reduceWindowSum.getBody().emplaceBlock();

  // Add bb argument
  auto blockArgumentType = RankedTensorType::get({}, inputElemTy);
  sumBlock.addArgument(blockArgumentType, op->getLoc());
  sumBlock.addArgument(blockArgumentType, op->getLoc());
  auto *firstArg = sumBlock.args_begin();
  auto *secondArg = std::next(firstArg);

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&sumBlock);

    Value sumResult =
        rewriter.create<stablehlo::AddOp>(op->getLoc(), *firstArg, *secondArg);
    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
  }

  rewriter.replaceOp(op, reduceWindowSum.getResults());
  return success();
}

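// Registers the pooling-related conversion patterns defined above and marks
// the corresponding Torch ops illegal for the TorchToStablehlo conversion.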
void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target, const TorchToStablehloOptions &options) {
  MLIRContext *context = patterns.getContext();
#define INSERT_ATEN_POOLING_PATTERN(AtenOp)                                    \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
  INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp);
  INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp);
#undef INSERT_ATEN_POOLING_PATTERN

#define INSERT_ATEN_MAXPOOL_PATTERN(AtenOp, Dim)                               \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenMaxPoolOp<AtenOp, Dim>>(typeConverter, context,      \
                                                  options)
  INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool1dOp, 1);
  INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool2dOp, 2);
  INSERT_ATEN_MAXPOOL_PATTERN(AtenMaxPool3dOp, 3);
#undef INSERT_ATEN_MAXPOOL_PATTERN

#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim)                               \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context,      \
                                                  options)
  INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
  INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
  INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool3dOp, 3);
#undef INSERT_ATEN_AVGPOOL_PATTERN
}