2022-08-03 10:47:52 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
2022-08-03 10:47:52 +08:00
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
2023-02-02 21:29:47 +08:00
|
|
|
#include "PopulatePatterns.h"
|
|
|
|
|
2022-10-05 21:28:06 +08:00
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2022-08-03 10:47:52 +08:00
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
2023-03-21 05:14:27 +08:00
|
|
|
#include "stablehlo/dialect/ChloOps.h"
|
2023-02-02 21:29:47 +08:00
|
|
|
#include "stablehlo/dialect/StablehloOps.h"
|
2024-01-30 01:59:33 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
2022-08-03 10:47:52 +08:00
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
|
|
|
2024-02-16 01:08:48 +08:00
|
|
|
#include <unordered_set>
|
|
|
|
#include <vector>
|
|
|
|
|
2022-08-03 10:47:52 +08:00
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
2023-02-02 21:29:47 +08:00
|
|
|
using namespace mlir::torch::torch_to_stablehlo;
|
2022-08-03 10:47:52 +08:00
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
/// Computes the static shape that results from reducing a tensor of shape
/// `inputShape` over the dimension indices listed in `dims`.
///
/// Every extent whose index appears in `dims` is dropped; the remaining
/// extents keep their original relative order. Entries of `dims` that do not
/// name a dimension of `inputShape` simply have no effect.
static SmallVector<int64_t> getReduceOutputShape(ArrayRef<int64_t> inputShape,
                                                 ArrayRef<int64_t> dims) {
  const std::unordered_set<int64_t> reducedAxes(dims.begin(), dims.end());
  const int64_t rank = static_cast<int64_t>(inputShape.size());

  SmallVector<int64_t> keptExtents;
  for (int64_t axis = 0; axis < rank; ++axis) {
    // Only axes that are not being reduced survive into the result shape.
    if (reducedAxes.count(axis) == 0)
      keptExtents.push_back(inputShape[axis]);
  }
  return keptExtents;
}
|
|
|
|
|
2022-08-03 10:47:52 +08:00
|
|
|
/// Materializes the rank-0 `stablehlo.constant` used as the initial value of
/// the reduction performed on behalf of `op`, for elements of `elementTy`:
///
///   sum / sum.dim_IntList / frobenius_norm.dim / linalg_vector_norm -> 0
///   amax / max / max.dim / argmax  -> -inf (float) or signed-min (integer)
///   amin / min / min.dim / argmin  -> +inf (float) or signed-max (integer)
///   prod                           -> 1
///   all                            -> 1-bit 1 (true)
///   any / any.dim                  -> 1-bit 0 (false)
///
/// For any other op, or when a numeric case sees an element type that is
/// neither a FloatType nor an IntegerType, an error is emitted on `op` and
/// nullptr is returned.
static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                                           PatternRewriter &rewriter) {
  auto constType = RankedTensorType::get({}, elementTy);
  auto buildConstant = [&](DenseElementsAttr payload) -> Value {
    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                  payload);
  };
  auto floatTy = dyn_cast<mlir::FloatType>(elementTy);
  bool isInteger = isa<mlir::IntegerType>(elementTy);

  // Additive identity.
  if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
          AtenLinalgVectorNormOp>(op)) {
    if (floatTy)
      return buildConstant(DenseElementsAttr::get(
          constType, {APFloat::getZero(floatTy.getFloatSemantics(),
                                       /*negative=*/false)}));
    if (isInteger)
      return buildConstant(DenseElementsAttr::get(
          constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}));
  }

  // Identity of `max`: the smallest representable value.
  if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
    if (floatTy)
      return buildConstant(DenseElementsAttr::get(
          constType, {APFloat::getInf(floatTy.getFloatSemantics(),
                                      /*negative=*/true)}));
    if (isInteger)
      return buildConstant(DenseElementsAttr::get(
          constType,
          {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}));
  }

  // Identity of `min`: the largest representable value.
  if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp, AtenArgminOp>(op)) {
    if (floatTy)
      return buildConstant(DenseElementsAttr::get(
          constType, {APFloat::getInf(floatTy.getFloatSemantics(),
                                      /*negative=*/false)}));
    if (isInteger)
      return buildConstant(DenseElementsAttr::get(
          constType,
          {APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())}));
  }

  // Multiplicative identity.
  if (isa<AtenProdOp>(op)) {
    if (floatTy) {
      APFloat one(floatTy.getFloatSemantics(), 1);
      return buildConstant(DenseElementsAttr::get(constType, one));
    }
    if (isInteger) {
      APInt one(elementTy.getIntOrFloatBitWidth(), 1);
      return buildConstant(DenseElementsAttr::get(constType, one));
    }
  }

  // Logical AND identity: true.
  if (isa<AtenAllOp>(op))
    return buildConstant(
        DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)}));

  // Logical OR identity: false.
  if (isa<AtenAnyOp, AtenAnyDimOp>(op))
    return buildConstant(
        DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)}));

  op->emitError("unimplemented lowering in "
                "createInitialValueForReduceOp");
  return nullptr;
}
|
|
|
|
|
2024-05-16 00:05:19 +08:00
|
|
|
/// Creates a `stablehlo.reduce` that reduces `input` over `dims` using one
/// scalar combiner chosen from the kind of `op`:
///   amax/max/max.dim -> max, amin/min/min.dim -> min,
///   sum/sum.dim_IntList/frobenius_norm.dim/linalg_vector_norm -> add,
///   all -> and, any/any.dim -> or, prod -> mul.
/// The reduction is seeded with `createInitialValueForReduceOp`.
///
/// \param op      the Torch op being lowered; selects init value and combiner.
/// \param input   the tensor to reduce; must have a RankedTensorType.
/// \param outTy   the declared result type of the created reduce op.
/// \param dims    the dimensions to reduce over.
/// \returns the reduce result, or nullptr if `input` is not ranked, no init
///          value exists for `op`, or `op` has no combiner here (an error is
///          emitted on `op` in the last two cases).
static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
                                              Type outTy,
                                              ArrayRef<int64_t> dims,
                                              PatternRewriter &rewriter) {
  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
  if (!inputTy)
    return nullptr;
  Value initValue =
      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
  if (!initValue)
    return nullptr;

  Location loc = op->getLoc();
  stablehlo::ReduceOp reduce = rewriter.create<stablehlo::ReduceOp>(
      loc, outTy, input, initValue, rewriter.getDenseI64ArrayAttr(dims));

  // The reducer body takes two rank-0 tensors of the input element type.
  Block &block = reduce.getBody().emplaceBlock();
  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
  Value lhs = block.addArgument(blockArgumentTy, loc);
  Value rhs = block.addArgument(blockArgumentTy, loc);

  Value result;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);
    if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) {
      result =
          rewriter.create<stablehlo::MaxOp>(loc, blockArgumentTy, lhs, rhs);
    } else if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp>(op)) {
      result =
          rewriter.create<stablehlo::MinOp>(loc, blockArgumentTy, lhs, rhs);
    } else if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
                   AtenLinalgVectorNormOp>(op)) {
      result =
          rewriter.create<stablehlo::AddOp>(loc, blockArgumentTy, lhs, rhs);
    } else if (isa<AtenAllOp>(op)) {
      result =
          rewriter.create<stablehlo::AndOp>(loc, blockArgumentTy, lhs, rhs);
    } else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
      result = rewriter.create<stablehlo::OrOp>(loc, blockArgumentTy, lhs, rhs);
    } else if (isa<AtenProdOp>(op)) {
      result =
          rewriter.create<stablehlo::MulOp>(loc, blockArgumentTy, lhs, rhs);
    }
    if (result)
      rewriter.create<stablehlo::ReturnOp>(loc, result);
  }

  if (!result) {
    op->emitError("unimplemented lowering in "
                  "createReduceOpWithSingleRegionOp");
    // Don't leave a stablehlo.reduce with an empty, unterminated body behind.
    rewriter.eraseOp(reduce);
    return nullptr;
  }
  return reduce.getResults()[0];
}
|
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
// Util for converting AtenMaxDimOp/AtenMinDimOp
|
2022-12-20 18:17:27 +08:00
|
|
|
static std::optional<ValueRange>
|
2024-06-29 16:53:33 +08:00
|
|
|
createReduceOpReturnIndices(ConversionPatternRewriter &rewriter, Operation *op,
|
|
|
|
Value &input, ArrayRef<Value> inputShapeVec,
|
|
|
|
int64_t dim, size_t dimSizeIndexBits) {
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
2022-08-03 10:47:52 +08:00
|
|
|
if (!inputTy) {
|
2022-12-14 18:44:05 +08:00
|
|
|
return std::nullopt;
|
2022-08-03 10:47:52 +08:00
|
|
|
}
|
|
|
|
if (!inputTy.getElementType().isIntOrFloat()) {
|
2022-12-14 18:44:05 +08:00
|
|
|
return std::nullopt;
|
2022-08-03 10:47:52 +08:00
|
|
|
}
|
|
|
|
auto inputShape = inputTy.getShape();
|
|
|
|
auto inputElemTy = inputTy.getElementType();
|
|
|
|
|
|
|
|
Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter);
|
2022-12-14 18:44:05 +08:00
|
|
|
if (!initValue)
|
|
|
|
return std::nullopt;
|
2022-08-09 17:02:50 +08:00
|
|
|
Value initIndex;
|
2022-09-01 10:36:02 +08:00
|
|
|
if (dimSizeIndexBits == 32) {
|
2023-02-02 21:29:47 +08:00
|
|
|
initIndex = hlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).value();
|
2022-08-09 17:02:50 +08:00
|
|
|
} else {
|
2023-02-02 21:29:47 +08:00
|
|
|
initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
|
2022-08-09 17:02:50 +08:00
|
|
|
}
|
2022-08-03 10:47:52 +08:00
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
auto outputShape = getReduceOutputShape(inputShape, {dim});
|
2024-02-16 01:08:48 +08:00
|
|
|
auto outputTy = RankedTensorType::get(outputShape, inputElemTy);
|
|
|
|
auto outputIndexTy =
|
|
|
|
RankedTensorType::get(outputShape, rewriter.getIntegerType(64));
|
|
|
|
|
2022-08-03 10:47:52 +08:00
|
|
|
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
|
|
|
op->getLoc(), inputShapeVec);
|
2023-02-02 21:29:47 +08:00
|
|
|
auto indexTensor = rewriter.create<stablehlo::DynamicIotaOp>(
|
2022-09-01 10:36:02 +08:00
|
|
|
op->getLoc(),
|
|
|
|
RankedTensorType::get(inputShape,
|
|
|
|
rewriter.getIntegerType(dimSizeIndexBits)),
|
2022-08-03 10:47:52 +08:00
|
|
|
inputShapeTensor, static_cast<uint64_t>(dim));
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
|
2024-02-16 01:08:48 +08:00
|
|
|
op->getLoc(), TypeRange{outputTy, outputIndexTy},
|
|
|
|
ValueRange{input, indexTensor},
|
2022-08-03 10:47:52 +08:00
|
|
|
ValueRange{
|
|
|
|
initValue,
|
|
|
|
initIndex,
|
|
|
|
},
|
Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821)
With the recent LLVM integrate and changes from
https://github.com/llvm/llvm-project/pull/78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
rewriter.startRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
rewriter.finalizeRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
rewriter.cancelRootUpdate(op);
~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```
I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it.
It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test
...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>
^
LLVM ERROR: Failed to infer result type(s).
```
Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
2024-02-01 06:21:17 +08:00
|
|
|
rewriter.getDenseI64ArrayAttr(dim));
|
2022-08-03 10:47:52 +08:00
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
Block &block = stablehloReduceOp.getBody().emplaceBlock();
|
2022-08-03 10:47:52 +08:00
|
|
|
|
|
|
|
// Add block arguments
|
|
|
|
auto blockValArgumentType =
|
|
|
|
RankedTensorType::get({}, inputTy.getElementType());
|
2022-09-01 10:36:02 +08:00
|
|
|
auto blockIdxArgumentType =
|
|
|
|
RankedTensorType::get({}, rewriter.getIntegerType(dimSizeIndexBits));
|
2022-08-03 10:47:52 +08:00
|
|
|
auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
|
|
|
|
block.addArgument(blockValArgumentType, op->getLoc());
|
|
|
|
block.addArgument(blockIdxArgumentType, op->getLoc());
|
|
|
|
|
|
|
|
block.addArgument(blockValArgumentType, op->getLoc());
|
|
|
|
block.addArgument(blockIdxArgumentType, op->getLoc());
|
|
|
|
|
|
|
|
auto *firstValArg = block.args_begin();
|
|
|
|
auto *firstIdxArg = std::next(firstValArg);
|
|
|
|
auto *secondValArg = std::next(firstIdxArg);
|
|
|
|
auto *secondIdxArg = std::next(secondValArg);
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
stablehlo::ComparisonTypeAttr compareTypeAttr;
|
2024-05-31 14:45:13 +08:00
|
|
|
if (isa<mlir::FloatType>(inputTy.getElementType())) {
|
2023-02-02 21:29:47 +08:00
|
|
|
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
|
|
|
rewriter.getContext(), stablehlo::ComparisonType::FLOAT);
|
2024-05-31 14:45:13 +08:00
|
|
|
} else if (isa<mlir::IntegerType>(inputTy.getElementType())) {
|
2023-02-02 21:29:47 +08:00
|
|
|
compareTypeAttr = stablehlo::ComparisonTypeAttr::get(
|
|
|
|
rewriter.getContext(), stablehlo::ComparisonType::SIGNED);
|
2022-08-03 10:47:52 +08:00
|
|
|
}
|
2023-02-02 21:29:47 +08:00
|
|
|
stablehlo::ComparisonDirectionAttr compareGeDirectionAttr =
|
|
|
|
stablehlo::ComparisonDirectionAttr::get(
|
|
|
|
rewriter.getContext(), stablehlo::ComparisonDirection::GE);
|
2024-06-29 16:53:33 +08:00
|
|
|
stablehlo::ComparisonDirectionAttr compareLeDirectionAttr =
|
|
|
|
stablehlo::ComparisonDirectionAttr::get(
|
|
|
|
rewriter.getContext(), stablehlo::ComparisonDirection::LE);
|
2023-02-02 21:29:47 +08:00
|
|
|
stablehlo::ComparisonDirectionAttr compareEqDirectionAttr =
|
|
|
|
stablehlo::ComparisonDirectionAttr::get(
|
|
|
|
rewriter.getContext(), stablehlo::ComparisonDirection::EQ);
|
2022-08-03 10:47:52 +08:00
|
|
|
|
|
|
|
{
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
|
|
rewriter.setInsertionPointToStart(&block);
|
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
Value compareResult;
|
|
|
|
if (isa<AtenMaxDimOp>(op)) {
|
|
|
|
compareResult = rewriter.create<stablehlo::CompareOp>(
|
|
|
|
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
|
|
|
compareGeDirectionAttr, compareTypeAttr);
|
|
|
|
} else if (isa<AtenMinDimOp>(op)) {
|
|
|
|
compareResult = rewriter.create<stablehlo::CompareOp>(
|
|
|
|
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
|
|
|
compareLeDirectionAttr, compareTypeAttr);
|
|
|
|
} else {
|
|
|
|
op->emitError("unimplement lowering of createReduceOpReturnIndices");
|
|
|
|
return std::nullopt;
|
|
|
|
}
|
2023-02-02 21:29:47 +08:00
|
|
|
Value retValResult = rewriter.create<stablehlo::SelectOp>(
|
2024-06-29 16:53:33 +08:00
|
|
|
op->getLoc(), compareResult, *firstValArg, *secondValArg);
|
2022-08-03 10:47:52 +08:00
|
|
|
|
|
|
|
// get smaller index value if compared nums are equal.
|
2023-02-02 21:29:47 +08:00
|
|
|
Value compareEqResult = rewriter.create<stablehlo::CompareOp>(
|
2022-08-03 10:47:52 +08:00
|
|
|
op->getLoc(), compareResultType, *firstValArg, *secondValArg,
|
|
|
|
compareEqDirectionAttr, compareTypeAttr);
|
2023-02-02 21:29:47 +08:00
|
|
|
Value minIdx = rewriter.create<stablehlo::MinOp>(op->getLoc(), *firstIdxArg,
|
|
|
|
*secondIdxArg);
|
|
|
|
Value idxWithGeVal = rewriter.create<stablehlo::SelectOp>(
|
2024-06-29 16:53:33 +08:00
|
|
|
op->getLoc(), compareResult, *firstIdxArg, *secondIdxArg);
|
2023-02-02 21:29:47 +08:00
|
|
|
Value retIdxResult = rewriter.create<stablehlo::SelectOp>(
|
2022-08-03 10:47:52 +08:00
|
|
|
op->getLoc(), compareEqResult, minIdx, idxWithGeVal);
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.create<stablehlo::ReturnOp>(
|
2024-06-29 16:53:33 +08:00
|
|
|
op->getLoc(), ValueRange{retValResult, retIdxResult});
|
2022-08-03 10:47:52 +08:00
|
|
|
}
|
2023-02-02 21:29:47 +08:00
|
|
|
return stablehloReduceOp.getResults();
|
2022-08-03 10:47:52 +08:00
|
|
|
}
|
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
/// Restores the reduced dimensions of `reduceResult` as size-1 extents by
/// emitting a `stablehlo.dynamic_reshape` to `outType`.
///
/// The target shape is `inputShapeVec` with each entry at an index listed in
/// `dims` replaced by an `arith.constant 1` of integer width
/// `dimSizeIndexBits`. The other entries are assumed to have that same width.
/// Every index in `dims` must be a valid position in `inputShapeVec`.
static Value reshapeReduceResultWhenKeepDim(ConversionPatternRewriter &rewriter,
                                            Location loc, Value reduceResult,
                                            ArrayRef<Value> inputShapeVec,
                                            Type outType,
                                            ArrayRef<int64_t> dims,
                                            size_t dimSizeIndexBits) {
  SmallVector<Value> keptShape(inputShapeVec);
  auto extentTy = rewriter.getIntegerType(dimSizeIndexBits);
  Value unitExtent = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(extentTy, 1));

  for (int64_t reducedDim : dims)
    keptShape[reducedDim] = unitExtent;

  auto targetShape = rewriter.create<tensor::FromElementsOp>(loc, keptShape);
  return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outType,
                                                      reduceResult, targetShape);
}
|
|
|
|
|
2022-08-03 10:47:52 +08:00
|
|
|
namespace {
|
|
|
|
template <typename AtenOpT>
|
2022-09-01 10:36:02 +08:00
|
|
|
class ConvertAtenReductionOp : public ConvertAtenOp<AtenOpT> {
|
2022-08-03 10:47:52 +08:00
|
|
|
public:
|
2022-09-01 10:36:02 +08:00
|
|
|
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
2022-08-03 10:47:52 +08:00
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
2024-05-23 20:40:20 +08:00
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
assert(false && "Unimplemented");
|
|
|
|
return failure();
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
template <typename AtenOpT>
|
|
|
|
class ConvertAtenReduceAllDimsOp : public ConvertAtenReductionOp<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
Value input = adaptor.getSelf();
|
|
|
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
|
|
|
auto outTy = dyn_cast<RankedTensorType>(
|
|
|
|
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
|
|
|
|
op.getType()));
|
|
|
|
if (!inputTy || !outTy) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only Tensor types supported in StableHLO");
|
|
|
|
}
|
|
|
|
|
|
|
|
auto inputElemTy = inputTy.getElementType();
|
|
|
|
if (!inputElemTy.isIntOrFloat()) {
|
|
|
|
return op.emitError(
|
|
|
|
"only floating-point or integer datatype legalization supported");
|
|
|
|
}
|
|
|
|
if (inputElemTy != outTy.getElementType()) {
|
|
|
|
// use output type as computation type
|
|
|
|
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
|
|
|
|
outTy.getElementType());
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<int64_t> dims =
|
|
|
|
llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
|
|
|
|
Value result =
|
|
|
|
createReduceOpWithSingleRegionOp(op, input, outTy, dims, rewriter);
|
|
|
|
if (!result) {
|
|
|
|
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, result);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
template <typename AtenOpT>
|
2024-06-29 16:53:33 +08:00
|
|
|
class ConvertAtenReduceOneDimOp : public ConvertAtenReductionOp<AtenOpT> {
|
2024-05-23 20:40:20 +08:00
|
|
|
public:
|
|
|
|
using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
Value input = adaptor.getSelf();
|
|
|
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
2024-06-29 16:53:33 +08:00
|
|
|
auto outTy = dyn_cast<RankedTensorType>(
|
|
|
|
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
|
|
|
|
op.getType()));
|
|
|
|
if (!inputTy || !outTy) {
|
2024-05-23 20:40:20 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only Tensor types supported in StableHLO");
|
|
|
|
}
|
|
|
|
|
|
|
|
auto inputElemTy = inputTy.getElementType();
|
|
|
|
if (!inputElemTy.isIntOrFloat()) {
|
|
|
|
return op.emitError(
|
|
|
|
"only floating-point or integer datatype legalization supported");
|
|
|
|
}
|
2024-06-29 16:53:33 +08:00
|
|
|
if (inputElemTy != outTy.getElementType()) {
|
|
|
|
// use output type as computation type
|
|
|
|
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
|
|
|
|
outTy.getElementType());
|
|
|
|
}
|
|
|
|
|
|
|
|
bool keepDim = false;
|
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
|
|
|
|
}
|
|
|
|
|
|
|
|
int64_t dim;
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
2024-05-23 20:40:20 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
2024-06-29 16:53:33 +08:00
|
|
|
op, "non-const integer `dim` is not supported");
|
|
|
|
}
|
|
|
|
dim = toPositiveDim(dim, inputTy.getRank());
|
|
|
|
SmallVector<int64_t> reduceResultShape =
|
|
|
|
getReduceOutputShape(inputTy.getShape(), {dim});
|
|
|
|
|
|
|
|
Value reduceResult = createReduceOpWithSingleRegionOp(
|
|
|
|
op, input,
|
|
|
|
RankedTensorType::get(reduceResultShape, outTy.getElementType()), {dim},
|
|
|
|
rewriter);
|
|
|
|
if (!reduceResult) {
|
|
|
|
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
|
|
|
|
}
|
|
|
|
|
|
|
|
if (keepDim) {
|
|
|
|
const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
|
|
|
|
auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
|
|
|
|
options.dimSizeIndexBits);
|
|
|
|
if (failed(outShapeInfo)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
}
|
|
|
|
reduceResult = reshapeReduceResultWhenKeepDim(
|
|
|
|
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim},
|
|
|
|
options.dimSizeIndexBits);
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, reduceResult);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
template <typename AtenOpT>
|
|
|
|
class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
Value input = adaptor.getSelf();
|
|
|
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
|
|
|
auto outTy = dyn_cast<RankedTensorType>(
|
|
|
|
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
|
|
|
|
op.getType()));
|
|
|
|
if (!inputTy || !outTy) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only Tensor types supported in StableHLO");
|
|
|
|
}
|
|
|
|
|
|
|
|
auto inputElemTy = inputTy.getElementType();
|
|
|
|
if (!inputElemTy.isIntOrFloat()) {
|
|
|
|
return op.emitError(
|
|
|
|
"only floating-point or integer datatype legalization supported");
|
|
|
|
}
|
|
|
|
if (inputElemTy != outTy.getElementType()) {
|
|
|
|
// use output type as computation type
|
|
|
|
input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input,
|
|
|
|
outTy.getElementType());
|
2024-05-23 20:40:20 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
bool keepDim = false;
|
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<int64_t> inputDims;
|
|
|
|
SmallVector<int64_t> dims;
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "non-const integer `dim` is not supported");
|
|
|
|
}
|
|
|
|
for (auto d : inputDims) {
|
|
|
|
d = toPositiveDim(d, inputTy.getRank());
|
|
|
|
// Drop invalid dims
|
|
|
|
if (isValidDim(d, inputTy.getRank())) {
|
|
|
|
dims.push_back(d);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
llvm::sort(dims.begin(), dims.end());
|
2024-06-29 16:53:33 +08:00
|
|
|
SmallVector<int64_t> reduceResultShape =
|
|
|
|
getReduceOutputShape(inputTy.getShape(), dims);
|
2024-05-23 20:40:20 +08:00
|
|
|
|
|
|
|
Value reduceResult = createReduceOpWithSingleRegionOp(
|
2024-06-29 16:53:33 +08:00
|
|
|
op, input,
|
|
|
|
RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims,
|
2024-05-23 20:40:20 +08:00
|
|
|
rewriter);
|
2024-06-29 16:53:33 +08:00
|
|
|
if (!reduceResult) {
|
2024-05-23 20:40:20 +08:00
|
|
|
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
|
2024-06-29 16:53:33 +08:00
|
|
|
}
|
2024-05-23 20:40:20 +08:00
|
|
|
|
|
|
|
if (keepDim) {
|
|
|
|
const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
|
|
|
|
auto outShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, input,
|
|
|
|
options.dimSizeIndexBits);
|
|
|
|
if (failed(outShapeInfo)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
}
|
2024-06-29 16:53:33 +08:00
|
|
|
reduceResult = reshapeReduceResultWhenKeepDim(
|
|
|
|
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
|
|
|
|
options.dimSizeIndexBits);
|
2024-05-23 20:40:20 +08:00
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, reduceResult);
|
|
|
|
return success();
|
|
|
|
}
|
2022-08-03 10:47:52 +08:00
|
|
|
};
|
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
template <typename AtenOpT>
|
|
|
|
class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp<AtenOpT> {
|
|
|
|
public:
|
|
|
|
using ConvertAtenReductionOp<AtenOpT>::ConvertAtenReductionOp;
|
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
|
|
Value input = adaptor.getSelf();
|
|
|
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
|
|
|
if (!inputTy) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only Tensor types supported in StableHLO");
|
|
|
|
}
|
|
|
|
auto inputElemTy = inputTy.getElementType();
|
|
|
|
if (!inputElemTy.isIntOrFloat()) {
|
|
|
|
return op.emitError(
|
|
|
|
"Only floating-point or integer datatype legalization supported");
|
|
|
|
}
|
2022-08-03 10:47:52 +08:00
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
RankedTensorType valResultType = cast<RankedTensorType>(
|
|
|
|
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
|
|
|
|
op.getResult(0).getType()));
|
|
|
|
RankedTensorType idxResultType = cast<RankedTensorType>(
|
|
|
|
ConvertAtenReductionOp<AtenOpT>::getTypeConverter()->convertType(
|
|
|
|
op.getResult(1).getType()));
|
|
|
|
Type idxElementType = idxResultType.getElementType();
|
|
|
|
if (!isa<mlir::IntegerType>(idxElementType)) {
|
|
|
|
return op.emitError("indices result should to be integer tyep");
|
|
|
|
}
|
2022-08-03 10:47:52 +08:00
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
int64_t dim;
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "non-int dim unsupported");
|
|
|
|
}
|
|
|
|
dim = toPositiveDim(dim, inputTy.getRank());
|
|
|
|
if (!isValidDim(dim, inputTy.getRank())) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
|
|
|
}
|
|
|
|
bool keepDim = false;
|
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
|
|
|
|
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
|
|
|
|
}
|
2022-08-03 10:47:52 +08:00
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
const auto &options = ConvertAtenReductionOp<AtenOpT>::getOptions();
|
|
|
|
auto inputShapeInfo =
|
|
|
|
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
|
|
|
if (failed(inputShapeInfo)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
}
|
|
|
|
auto inputShapeVec = *inputShapeInfo;
|
2022-08-03 10:47:52 +08:00
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
if (op.getResult(1).use_empty()) {
|
|
|
|
llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
|
|
|
|
outputShape.erase(outputShape.begin() + dim);
|
|
|
|
Value reduceResult = createReduceOpWithSingleRegionOp(
|
|
|
|
op, input, RankedTensorType::get(outputShape, inputElemTy),
|
|
|
|
ArrayRef<int64_t>{dim}, rewriter);
|
|
|
|
if (!reduceResult) {
|
|
|
|
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
|
|
|
|
}
|
2024-05-16 00:05:19 +08:00
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
if (keepDim) {
|
|
|
|
reduceResult = reshapeReduceResultWhenKeepDim(
|
|
|
|
rewriter, op->getLoc(), reduceResult, inputShapeVec, valResultType,
|
|
|
|
{dim}, options.dimSizeIndexBits);
|
|
|
|
}
|
|
|
|
rewriter.replaceOp(op, {reduceResult, Value()});
|
2024-05-16 00:05:19 +08:00
|
|
|
return success();
|
2024-06-29 16:53:33 +08:00
|
|
|
} else {
|
|
|
|
ValueRange stablehloReduceResults =
|
|
|
|
createReduceOpReturnIndices(rewriter, op, input, inputShapeVec, dim,
|
|
|
|
options.dimSizeIndexBits)
|
|
|
|
.value();
|
|
|
|
if (keepDim) {
|
|
|
|
stablehloReduceResults[0] = reshapeReduceResultWhenKeepDim(
|
|
|
|
rewriter, op->getLoc(), stablehloReduceResults[0], inputShapeVec,
|
|
|
|
valResultType, {dim}, options.dimSizeIndexBits);
|
|
|
|
stablehloReduceResults[1] = reshapeReduceResultWhenKeepDim(
|
|
|
|
rewriter, op->getLoc(), stablehloReduceResults[1], inputShapeVec,
|
|
|
|
idxResultType, {dim}, options.dimSizeIndexBits);
|
|
|
|
}
|
2024-05-16 00:05:19 +08:00
|
|
|
rewriter.replaceOp(
|
2024-06-29 16:53:33 +08:00
|
|
|
op, {stablehloReduceResults[0], stablehloReduceResults[1]});
|
2024-05-16 00:05:19 +08:00
|
|
|
return success();
|
|
|
|
}
|
2024-06-29 16:53:33 +08:00
|
|
|
};
|
|
|
|
};
|
2022-08-03 10:47:52 +08:00
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// AtenSumDimIntListOp
|
|
|
|
namespace {
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
|
|
|
AtenSumDimIntListOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
2024-05-31 14:45:13 +08:00
|
|
|
auto outTy =
|
|
|
|
dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
2022-08-03 10:47:52 +08:00
|
|
|
if (!inputTy) {
|
2023-02-02 21:29:47 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only Tensor types supported in StableHLO");
|
2022-08-03 10:47:52 +08:00
|
|
|
}
|
2022-11-23 15:02:41 +08:00
|
|
|
if (inputTy.getElementType() != outTy.getElementType()) {
|
|
|
|
// Use output element type as computation type.
|
|
|
|
auto dstElemTy = outTy.getElementType();
|
2023-02-02 21:29:47 +08:00
|
|
|
input =
|
|
|
|
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
2024-04-28 05:00:56 +08:00
|
|
|
inputTy = dyn_cast<RankedTensorType>(input.getType());
|
2022-08-03 10:47:52 +08:00
|
|
|
}
|
|
|
|
auto inputElemTy = inputTy.getElementType();
|
|
|
|
if (!inputElemTy.isIntOrFloat()) {
|
|
|
|
return op.emitError(
|
|
|
|
"Only floating-point or integer datatype legalization supported");
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<int64_t> inputDims;
|
|
|
|
SmallVector<int64_t> dims;
|
2024-04-24 11:25:46 +08:00
|
|
|
if (failed(checkNotNone(rewriter, op, op.getDim()))) {
|
2022-08-23 16:47:21 +08:00
|
|
|
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
|
2024-04-24 11:25:46 +08:00
|
|
|
} else {
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "non-const integer `dim` is not supported");
|
|
|
|
}
|
|
|
|
if (inputDims.size() == 0) {
|
|
|
|
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
|
|
|
|
}
|
2022-08-23 16:47:21 +08:00
|
|
|
}
|
2022-08-03 10:47:52 +08:00
|
|
|
for (auto d : inputDims) {
|
|
|
|
d = toPositiveDim(d, inputTy.getRank());
|
|
|
|
// Drop invalid dims
|
|
|
|
if (isValidDim(d, inputTy.getRank())) {
|
|
|
|
dims.push_back(d);
|
|
|
|
}
|
|
|
|
}
|
2024-06-29 16:53:33 +08:00
|
|
|
llvm::sort(dims.begin(), dims.end());
|
2022-08-03 10:47:52 +08:00
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
SmallVector<int64_t> reduceResultShape =
|
|
|
|
getReduceOutputShape(inputTy.getShape(), dims);
|
2024-02-16 01:08:48 +08:00
|
|
|
|
2022-08-03 10:47:52 +08:00
|
|
|
bool keepDim = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
|
2022-08-03 10:47:52 +08:00
|
|
|
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
|
|
|
|
}
|
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
Value reduceResult = createReduceOpWithSingleRegionOp(
|
|
|
|
op, input,
|
|
|
|
RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims,
|
|
|
|
rewriter);
|
|
|
|
if (!reduceResult) {
|
|
|
|
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
|
2022-08-03 10:47:52 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
if (keepDim) {
|
2022-09-01 10:36:02 +08:00
|
|
|
const auto &options = getOptions();
|
2023-02-02 21:29:47 +08:00
|
|
|
auto outShapeInfo =
|
|
|
|
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
2022-08-03 10:47:52 +08:00
|
|
|
if (failed(outShapeInfo)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
}
|
2024-06-29 16:53:33 +08:00
|
|
|
reduceResult = reshapeReduceResultWhenKeepDim(
|
|
|
|
rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims,
|
|
|
|
options.dimSizeIndexBits);
|
2022-08-03 10:47:52 +08:00
|
|
|
}
|
2024-06-29 16:53:33 +08:00
|
|
|
rewriter.replaceOp(op, reduceResult);
|
2022-08-03 10:47:52 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
2022-09-08 10:15:36 +08:00
|
|
|
// AtenFrobeniusNormDimOp
|
2023-02-02 21:29:47 +08:00
|
|
|
// aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given
|
2023-03-21 05:14:27 +08:00
|
|
|
// dims) + stablehlo.sqrt
|
2022-09-08 10:15:36 +08:00
|
|
|
namespace {
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
|
|
|
AtenFrobeniusNormDimOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2023-02-02 21:29:47 +08:00
|
|
|
const TorchToStablehloOptions &options = getOptions();
|
2022-09-08 10:15:36 +08:00
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
Value input = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
2022-09-08 10:15:36 +08:00
|
|
|
if (!inputType) {
|
|
|
|
return op.emitError(
|
|
|
|
"only ranked tensor input supported in AtenFrobeniusNormDimOp");
|
|
|
|
}
|
|
|
|
auto inputRank = inputType.getRank();
|
|
|
|
auto inputElemType = inputType.getElementType();
|
2024-04-11 21:47:35 +08:00
|
|
|
if (!isa<mlir::FloatType>(inputElemType)) {
|
2022-09-08 10:15:36 +08:00
|
|
|
return op.emitError(
|
|
|
|
"only float dtype allowed in input tensor of AtenFrobeniusNormDimOp");
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<int64_t> dims;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) {
|
2022-09-08 10:15:36 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "non-const integer `dim` is not supported");
|
|
|
|
}
|
|
|
|
for (auto &dim : dims) {
|
|
|
|
dim = toPositiveDim(dim, inputRank);
|
|
|
|
if (!isValidDim(dim, inputRank)) {
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"invalid dimension detected in `dim`");
|
|
|
|
}
|
|
|
|
}
|
2023-02-02 21:29:47 +08:00
|
|
|
// Sort the dims in ascending order, making the conversion
|
2022-09-08 10:15:36 +08:00
|
|
|
// stable with unordered dims.
|
|
|
|
std::sort(dims.begin(), dims.end());
|
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
SmallVector<int64_t> reduceResultShape =
|
|
|
|
getReduceOutputShape(inputType.getShape(), dims);
|
2024-02-16 01:08:48 +08:00
|
|
|
|
2022-09-08 10:15:36 +08:00
|
|
|
bool keepDim = false;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
|
2022-09-08 10:15:36 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "non-const bool `keepdim` is not supported");
|
|
|
|
}
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
auto squareOp = rewriter.create<stablehlo::MulOp>(op->getLoc(), input, input);
|
2023-01-18 05:52:12 +08:00
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
Value reduceResult = createReduceOpWithSingleRegionOp(
|
|
|
|
op, squareOp.getResult(),
|
|
|
|
RankedTensorType::get(reduceResultShape, inputElemType), dims, rewriter);
|
|
|
|
if (!reduceResult) {
|
|
|
|
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
|
2022-09-08 10:15:36 +08:00
|
|
|
}
|
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
Value output = rewriter.create<stablehlo::SqrtOp>(op->getLoc(), reduceResult);
|
2022-09-08 10:15:36 +08:00
|
|
|
|
|
|
|
if (keepDim) {
|
2023-02-02 21:29:47 +08:00
|
|
|
auto outShapeInfo =
|
|
|
|
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
2022-09-08 10:15:36 +08:00
|
|
|
if (failed(outShapeInfo)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
}
|
2024-06-29 16:53:33 +08:00
|
|
|
output = reshapeReduceResultWhenKeepDim(
|
|
|
|
rewriter, op->getLoc(), output, *outShapeInfo,
|
|
|
|
getTypeConverter()->convertType(op.getType()), dims,
|
|
|
|
options.dimSizeIndexBits);
|
2022-09-08 10:15:36 +08:00
|
|
|
}
|
2024-06-29 16:53:33 +08:00
|
|
|
rewriter.replaceOp(op, output);
|
2022-09-08 10:15:36 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
2023-03-21 05:14:27 +08:00
|
|
|
// AtenLinalgVectorNormOp
|
|
|
|
namespace {
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
|
|
|
|
AtenLinalgVectorNormOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
const TorchToStablehloOptions &options = getOptions();
|
|
|
|
|
|
|
|
Value input = adaptor.getSelf();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
2023-03-21 05:14:27 +08:00
|
|
|
if (!inputType) {
|
|
|
|
return op.emitError(
|
|
|
|
"only ranked tensor input supported in AtenLinalgVectorNormOp");
|
|
|
|
}
|
|
|
|
int64_t inputRank = inputType.getRank();
|
|
|
|
|
|
|
|
auto outType =
|
2024-04-28 05:00:56 +08:00
|
|
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
2023-03-21 05:14:27 +08:00
|
|
|
auto outElemType = outType.getElementType();
|
2024-04-11 21:47:35 +08:00
|
|
|
if (!isa<mlir::FloatType>(outElemType)) {
|
2023-03-21 05:14:27 +08:00
|
|
|
return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp");
|
|
|
|
}
|
|
|
|
|
|
|
|
if (inputType.getElementType() != outType.getElementType()) {
|
|
|
|
input =
|
|
|
|
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, outElemType);
|
|
|
|
}
|
|
|
|
|
|
|
|
Value ord =
|
|
|
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOrd(), outElemType);
|
|
|
|
|
|
|
|
SmallVector<int64_t> dims;
|
|
|
|
if (failed(checkNotNone(rewriter, op, op.getDim()))) {
|
|
|
|
dims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputRank));
|
|
|
|
} else {
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "non-const integer `dim` is not supported");
|
|
|
|
}
|
|
|
|
|
|
|
|
for (auto &dim : dims) {
|
|
|
|
dim = toPositiveDim(dim, inputRank);
|
|
|
|
if (!isValidDim(dim, inputRank)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "invalid dimension detected in `dim`");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Sort the dims in ascending order, making the conversion
|
|
|
|
// stable with unordered dims.
|
|
|
|
std::sort(dims.begin(), dims.end());
|
|
|
|
}
|
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
SmallVector<int64_t> reduceResultShape =
|
|
|
|
getReduceOutputShape(inputType.getShape(), dims);
|
2024-02-16 01:08:48 +08:00
|
|
|
|
2023-03-21 05:14:27 +08:00
|
|
|
bool keepDim = false;
|
|
|
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "non-const bool `keepdim` is not supported");
|
|
|
|
}
|
|
|
|
|
|
|
|
Value absValue = rewriter.create<stablehlo::AbsOp>(op->getLoc(), input);
|
|
|
|
Value powValue = rewriter.create<chlo::BroadcastPowOp>(op->getLoc(), absValue,
|
|
|
|
ord, nullptr);
|
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
Value reduceResult = createReduceOpWithSingleRegionOp(
|
|
|
|
op, powValue, RankedTensorType::get(reduceResultShape, outElemType), dims,
|
|
|
|
rewriter);
|
|
|
|
if (!reduceResult) {
|
|
|
|
return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
|
2023-03-21 05:14:27 +08:00
|
|
|
}
|
2024-06-29 16:53:33 +08:00
|
|
|
|
|
|
|
auto scalarType = RankedTensorType::get({}, outElemType);
|
2023-03-21 05:14:27 +08:00
|
|
|
auto constantOne = rewriter.create<stablehlo::ConstantOp>(
|
2024-06-29 16:53:33 +08:00
|
|
|
op->getLoc(), scalarType,
|
2023-03-21 05:14:27 +08:00
|
|
|
DenseElementsAttr::get(
|
2024-06-29 16:53:33 +08:00
|
|
|
scalarType,
|
2024-04-11 21:47:35 +08:00
|
|
|
APFloat(cast<mlir::FloatType>(outElemType).getFloatSemantics(), 1)));
|
2023-03-21 05:14:27 +08:00
|
|
|
auto reciprocalOrd = rewriter.create<stablehlo::DivOp>(
|
2024-06-29 16:53:33 +08:00
|
|
|
op->getLoc(), scalarType, constantOne, ord);
|
|
|
|
Value output = rewriter.create<chlo::BroadcastPowOp>(
|
|
|
|
op->getLoc(), reduceResult, reciprocalOrd, nullptr);
|
2023-03-21 05:14:27 +08:00
|
|
|
|
|
|
|
if (keepDim) {
|
|
|
|
auto outShapeInfo =
|
|
|
|
hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
|
|
|
|
if (failed(outShapeInfo)) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
}
|
2024-06-29 16:53:33 +08:00
|
|
|
output = reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), output,
|
|
|
|
*outShapeInfo, outType, dims,
|
|
|
|
options.dimSizeIndexBits);
|
2023-03-21 05:14:27 +08:00
|
|
|
}
|
2024-06-29 16:53:33 +08:00
|
|
|
rewriter.replaceOp(op, output);
|
2023-03-21 05:14:27 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
|
2022-08-03 10:47:52 +08:00
|
|
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
2023-02-02 21:29:47 +08:00
|
|
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
2022-08-03 10:47:52 +08:00
|
|
|
MLIRContext *context = patterns.getContext();
|
|
|
|
#define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2022-09-01 10:36:02 +08:00
|
|
|
patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
|
2022-08-03 10:47:52 +08:00
|
|
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
|
2022-09-08 10:15:36 +08:00
|
|
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
|
2023-03-21 05:14:27 +08:00
|
|
|
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
|
2022-08-03 10:47:52 +08:00
|
|
|
#undef INSERT_ATEN_REDUCTION_OP_PATTERN
|
2024-05-23 20:40:20 +08:00
|
|
|
|
|
|
|
#define INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenOp) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
|
|
|
patterns.add<ConvertAtenReduceAllDimsOp<AtenOp>>(typeConverter, context, \
|
|
|
|
options)
|
|
|
|
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMaxOp);
|
|
|
|
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMinOp);
|
|
|
|
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenSumOp);
|
|
|
|
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenProdOp);
|
|
|
|
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAllOp);
|
|
|
|
INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAnyOp);
|
|
|
|
#undef INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN
|
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
#define INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenOp) \
|
2024-05-23 20:40:20 +08:00
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2024-06-29 16:53:33 +08:00
|
|
|
patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context, \
|
|
|
|
options)
|
|
|
|
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
|
|
|
|
#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN
|
|
|
|
|
|
|
|
#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
|
|
|
patterns.add<ConvertAtenReduceDimsOp<AtenOp>>(typeConverter, context, options)
|
|
|
|
INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAmaxOp);
|
|
|
|
INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAminOp);
|
|
|
|
#undef INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN
|
2024-05-23 20:40:20 +08:00
|
|
|
|
2024-06-29 16:53:33 +08:00
|
|
|
#define INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenOp) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
|
|
|
patterns.add<ConvertAtenReduceWithIndicesOp<AtenOp>>(typeConverter, context, \
|
|
|
|
options)
|
|
|
|
INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMaxDimOp);
|
|
|
|
INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMinDimOp);
|
|
|
|
#undef INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN
|
2022-08-03 10:47:52 +08:00
|
|
|
}
|