2022-07-27 13:07:51 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
2023-03-28 12:16:21 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
2023-02-02 21:29:47 +08:00
|
|
|
|
2022-10-05 21:28:06 +08:00
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2024-05-26 12:34:56 +08:00
|
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
2024-04-29 17:40:30 +08:00
|
|
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
2022-08-02 09:21:37 +08:00
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
2024-05-26 12:34:56 +08:00
|
|
|
#include "stablehlo/dialect/ChloOps.h"
|
2023-02-02 21:29:47 +08:00
|
|
|
#include "stablehlo/dialect/StablehloOps.h"
|
|
|
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
2022-07-27 13:07:51 +08:00
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
2022-08-02 09:21:37 +08:00
|
|
|
#include <numeric>
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
|
|
|
|
|
|
|
namespace mlir {
|
2023-02-02 21:29:47 +08:00
|
|
|
namespace hlo {
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2024-05-26 12:34:56 +08:00
|
|
|
// Create chlo::ConstantLikeOp
|
|
|
|
template <typename T>
|
|
|
|
Value getConstantLike(OpBuilder &rewriter, Location loc, T constant,
|
|
|
|
Value val) {
|
|
|
|
Type ty = getElementTypeOrSelf(val.getType());
|
|
|
|
auto getAttr = [&]() -> Attribute {
|
|
|
|
if (isa<mlir::IntegerType>(ty))
|
|
|
|
return rewriter.getIntegerAttr(ty, constant);
|
|
|
|
if (isa<mlir::FloatType>(ty))
|
|
|
|
return rewriter.getFloatAttr(ty, constant);
|
|
|
|
if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
|
|
|
|
return mlir::complex::NumberAttr::get(complexTy, constant, 0);
|
|
|
|
llvm_unreachable("unhandled element type");
|
|
|
|
};
|
|
|
|
return rewriter.create<mlir::chlo::ConstantLikeOp>(
|
|
|
|
loc, cast<TypedAttr>(getAttr()), val);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Template instantiation
|
|
|
|
template Value getConstantLike<int64_t>(OpBuilder &rewriter, Location loc,
|
|
|
|
int64_t constant, Value val);
|
|
|
|
|
|
|
|
template Value getConstantLike<double>(OpBuilder &rewriter, Location loc,
|
|
|
|
double constant, Value val);
|
|
|
|
|
2022-07-27 13:07:51 +08:00
|
|
|
// Create a 32-bit float constant operator from a float
|
2023-02-02 21:29:47 +08:00
|
|
|
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
|
|
|
|
float val) {
|
2022-07-27 13:07:51 +08:00
|
|
|
auto const_type = RankedTensorType::get({}, rewriter.getF32Type());
|
|
|
|
auto const_attr = DenseElementsAttr::get(const_type, val);
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
|
|
|
op->getLoc(), const_type, const_attr);
|
2022-07-27 13:07:51 +08:00
|
|
|
return const_op.getResult();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create a 64-bit float constant operator from a double
|
2023-02-02 21:29:47 +08:00
|
|
|
Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op,
|
|
|
|
double val) {
|
2022-07-27 13:07:51 +08:00
|
|
|
auto const_type = RankedTensorType::get({}, rewriter.getF64Type());
|
|
|
|
auto const_attr = DenseElementsAttr::get(const_type, val);
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
|
|
|
op->getLoc(), const_type, const_attr);
|
2022-07-27 13:07:51 +08:00
|
|
|
return const_op.getResult();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Templated function to create a constant op for given type and shape.
|
|
|
|
// T: storage C type.
|
|
|
|
// Default template creates a constant tensor in T.
|
|
|
|
template <typename T>
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
|
|
|
|
ArrayRef<T> vec, ArrayRef<int64_t> shape) {
|
2022-07-27 13:07:51 +08:00
|
|
|
uint64_t num_total_elements = 1;
|
|
|
|
for (int64_t a : shape) {
|
|
|
|
num_total_elements *= a;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (vec.size() != num_total_elements) {
|
|
|
|
op->emitOpError("getConstTensor(): number of elements mismatch.");
|
2022-12-14 18:44:05 +08:00
|
|
|
return std::nullopt;
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
|
2024-04-02 05:18:49 +08:00
|
|
|
RankedTensorType const_type;
|
|
|
|
if constexpr (std::is_same_v<T, APInt>) {
|
|
|
|
const_type = RankedTensorType::get(
|
|
|
|
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
|
|
|
|
} else if constexpr (std::is_same_v<T, float>) {
|
|
|
|
const_type = RankedTensorType::get(shape, rewriter.getF32Type());
|
|
|
|
} else if constexpr (std::is_same_v<T, double>) {
|
|
|
|
const_type = RankedTensorType::get(shape, rewriter.getF64Type());
|
|
|
|
} else {
|
|
|
|
const_type =
|
|
|
|
RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8));
|
2022-07-27 13:07:51 +08:00
|
|
|
}
|
|
|
|
auto const_attr = DenseElementsAttr::get(const_type, vec);
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
|
|
|
op->getLoc(), const_type, const_attr);
|
2022-07-27 13:07:51 +08:00
|
|
|
return const_op.getResult();
|
|
|
|
}
|
|
|
|
|
2024-04-02 05:18:49 +08:00
|
|
|
// Template instantiation
|
|
|
|
template std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
|
|
|
|
Operation *op,
|
|
|
|
ArrayRef<APInt> vec,
|
|
|
|
ArrayRef<int64_t> shape);
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2024-04-02 05:18:49 +08:00
|
|
|
template std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
|
|
|
|
Operation *op,
|
|
|
|
ArrayRef<float> vec,
|
|
|
|
ArrayRef<int64_t> shape);
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2024-04-02 05:18:49 +08:00
|
|
|
template std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
|
|
|
|
Operation *op,
|
|
|
|
ArrayRef<double> vec,
|
|
|
|
ArrayRef<int64_t> shape);
|
2022-07-27 13:07:51 +08:00
|
|
|
|
2022-12-20 18:17:27 +08:00
|
|
|
template std::optional<Value> getConstTensor<int32_t>(PatternRewriter &,
|
|
|
|
Operation *,
|
|
|
|
ArrayRef<int32_t> vec,
|
|
|
|
ArrayRef<int64_t> shape);
|
|
|
|
|
|
|
|
template std::optional<Value> getConstTensor<int64_t>(PatternRewriter &,
|
|
|
|
Operation *,
|
|
|
|
ArrayRef<int64_t> vec,
|
|
|
|
ArrayRef<int64_t> shape);
|
2022-07-27 13:07:51 +08:00
|
|
|
|
|
|
|
// Check whether a scalar — carried either as `doubleValue` (when `isFloat`)
// or as `intValue` (when `isInt`) — is exactly representable in type T.
template <typename T>
static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
                           const int64_t &intValue) {
  if (isFloat) {
    // Do a round-trip check here instead of numeric limits due to
    // compiler warnings around double <-> int conversion.
    return (doubleValue == static_cast<double>(static_cast<T>(doubleValue)));
  }
  // Integer path: a simple range check against T's bounds.
  assert(isInt);
  return (intValue >= std::numeric_limits<T>::min()) &&
         (intValue <= std::numeric_limits<T>::max());
}
|
|
|
|
|
|
|
|
template <typename T>
|
2022-08-02 12:53:24 +08:00
|
|
|
Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
|
|
|
|
T val, Type dtype, llvm::ArrayRef<int64_t> dshape) {
|
|
|
|
auto const_type = RankedTensorType::get(dshape, dtype);
|
2022-07-27 13:07:51 +08:00
|
|
|
auto const_attr = SplatElementsAttr::get(const_type, val);
|
2023-02-02 21:29:47 +08:00
|
|
|
auto const_op = rewriter.create<stablehlo::ConstantOp>(
|
|
|
|
op->getLoc(), const_type, const_attr);
|
2022-07-27 13:07:51 +08:00
|
|
|
return const_op.getResult();
|
|
|
|
}
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
// Wrap a scalar SSA value into a rank-0 tensor of element type `dtype`.
Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
                              Operation *op, Value scalarValue, Type dtype) {
  Location loc = op->getLoc();
  // 1) pack the scalar into a 1-element tensor,
  Value asTensor = rewriter.create<tensor::FromElementsOp>(
      loc, ArrayRef<Value>{scalarValue});
  // 2) convert it to the requested element type,
  Value converted =
      rewriter.create<stablehlo::ConvertOp>(loc, asTensor, dtype);
  // 3) reshape the 1-element tensor down to rank 0.
  return rewriter.create<stablehlo::ReshapeOp>(
      loc, RankedTensorType::get(mlir::ArrayRef<int64_t>{}, dtype), converted);
}
|
|
|
|
|
2023-06-26 00:04:17 +08:00
|
|
|
// Convert `input` to element type `outElementType`; a no-op when the element
// types already match.
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
                  Type outElementType) {
  auto srcTy = cast<TensorType>(input.getType());
  if (srcTy.getElementType() == outElementType)
    return input;
  return rewriter.create<stablehlo::ConvertOp>(loc, input, outElementType);
}
|
|
|
|
|
|
|
|
// Promote the element type of `input` to that of `outType` and, when the
// shapes differ, broadcast it to `outType`'s shape.
//
// Two tensors are "broadcastable" if the following rules hold:
// - Each tensor has at least one dimension.
// - When iterating over the dimension sizes, starting at the trailing
//   dimension, the dimension sizes must either be equal, one of them is 1, or
//   one of them does not exist.
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
                          TensorType outType) {
  // NOTE(review): assumes `input` is an op result; for a block argument
  // getDefiningOp() returns null and the uses below would crash — confirm
  // with callers.
  Operation *op = input.getDefiningOp();
  TensorType in_type = dyn_cast<TensorType>(input.getType());
  // Previously the dyn_cast result was used unchecked; make the precondition
  // explicit instead of null-dereferencing on a non-tensor input.
  assert(in_type && "promoteAndBroadcast: input must have a tensor type");

  // Element-type promotion, if needed.
  if (in_type.getElementType() != outType.getElementType()) {
    TensorType promoted_type =
        in_type.cloneWith(in_type.getShape(), outType.getElementType());
    input = rewriter.create<stablehlo::ConvertOp>(op->getLoc(), promoted_type,
                                                  input);
  }

  ArrayRef<int64_t> inShape = in_type.getShape();
  ArrayRef<int64_t> outShape = outType.getShape();

  // A rank difference always requires a broadcast.
  bool do_bcast = (inShape.size() != outShape.size());
  SmallVector<int64_t> bcastDims;
  for (size_t i = 0; i < inShape.size(); ++i) {
    // iterating over the dimension sizes, starting at the trailing dimension
    size_t outPos = outShape.size() - 1 - i;
    size_t inPos = inShape.size() - 1 - i;
    int64_t outDim = outShape[outPos];
    int64_t inDim = inShape[inPos];
    if (inDim == outDim) {
      bcastDims.push_back(outPos);
    } else if (inDim != outDim && inDim == 1) {
      bcastDims.push_back(outPos);
      do_bcast = true;
    } else {
      // Incompatible dimension: emit a diagnostic. Execution falls through
      // (this function cannot signal failure); callers are expected to have
      // validated the shapes beforehand.
      // Fixed the diagnostic text: the stream pieces previously concatenated
      // without spaces ("...)must match..." / ")at non-singleton...").
      op->emitError("The size of tensor a (")
          << inDim << ") must match the size of tensor b (" << outDim
          << ") at non-singleton dimension " << inPos;
    }
  }
  std::reverse(bcastDims.begin(), bcastDims.end());
  if (!do_bcast) {
    return input;
  }
  auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims);
  auto bcast_op = rewriter.create<stablehlo::BroadcastInDimOp>(
      op->getLoc(), outType, input, bcast_attr);
  return bcast_op.getResult();
}
|
2022-08-02 09:21:37 +08:00
|
|
|
|
2023-12-08 15:13:42 +08:00
|
|
|
// Normalize possibly-negative dimension indices into the range [0, rank).
SmallVector<int64_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
  SmallVector<int64_t> result;
  result.reserve(rank);
  for (int64_t d : dims)
    result.push_back(toPositiveDim(d, rank));
  return result;
}
|
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
// Collect the runtime sizes of the dimensions listed in `inpDims` of `value`
// as integers of width `dimSizeIndexBits`. Fails when `value` is not a
// ranked tensor.
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
                                                     Operation *op, Value value,
                                                     ArrayRef<int64_t> inpDims,
                                                     size_t dimSizeIndexBits) {
  auto tensorTy = dyn_cast<RankedTensorType>(value.getType());
  if (!tensorTy)
    return rewriter.notifyMatchFailure(
        op, "getDimSizesOfTensor(): the input is not a ranked tensor");

  auto loc = op->getLoc();
  auto intTy = rewriter.getIntegerType(dimSizeIndexBits);
  auto normalized = toPositiveDims(inpDims, tensorTy.getRank());
  SmallVector<Value, 4> sizes;
  sizes.reserve(normalized.size());
  for (auto dim : normalized) {
    // tensor.dim yields an index value; cast it to the requested int width.
    Value idx = rewriter.create<tensor::DimOp>(loc, value, dim);
    sizes.emplace_back(rewriter.create<arith::IndexCastOp>(loc, intTy, idx));
  }
  return sizes;
}
|
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
// Collect the runtime sizes of all dimensions of `value` as integers of
// width `dimSizeIndexBits`.
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
                                                     Operation *op, Value value,
                                                     size_t dimSizeIndexBits) {
  auto tensorTy = dyn_cast<RankedTensorType>(value.getType());
  if (!tensorTy)
    return rewriter.notifyMatchFailure(
        op, "getDimSizesOfTensor(): the input is not a ranked tensor");

  // Request every axis: [0, 1, ..., rank-1].
  std::vector<int64_t> allDims(tensorTy.getRank());
  std::iota(allDims.begin(), allDims.end(), 0);
  return getDimSizesOfTensor(rewriter, op, value, allDims, dimSizeIndexBits);
}
|
|
|
|
|
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%0 = arith.index_cast %dim : index to i64
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%1 = arith.index_cast %dim_0 : index to i64
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%2 = arith.index_cast %dim_1 : index to i64
%from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
%3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %4 : tensor<?x?x?xf32>
}
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
%0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
%1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
}
```
The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
|
|
|
// Get the dimension sizes of the input tensor, given the dimension axes
|
|
|
|
FailureOr<SmallVector<Value, 4>>
|
|
|
|
getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
|
|
|
|
ArrayRef<int64_t> inpDims) {
|
|
|
|
auto valueTy = dyn_cast<RankedTensorType>(value.getType());
|
|
|
|
if (!valueTy) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "getDimIndexOfTensor(): the input is not a ranked tensor");
|
|
|
|
}
|
|
|
|
|
|
|
|
auto rank = valueTy.getRank();
|
|
|
|
auto dims = toPositiveDims(inpDims, rank);
|
|
|
|
SmallVector<Value, 4> dimSizes;
|
|
|
|
dimSizes.reserve(dims.size());
|
|
|
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
for (auto d : dims) {
|
|
|
|
dimSizes.emplace_back(rewriter.create<tensor::DimOp>(loc, value, d));
|
|
|
|
}
|
|
|
|
return dimSizes;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Get the dimension sizes of the input tensor
|
|
|
|
FailureOr<SmallVector<Value, 4>>
|
|
|
|
getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
|
|
|
|
auto valueTy = dyn_cast<RankedTensorType>(value.getType());
|
|
|
|
if (!valueTy) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "getDimIndexOfTensor(): the input is not a ranked tensor");
|
|
|
|
}
|
|
|
|
|
|
|
|
auto rank = valueTy.getRank();
|
|
|
|
// Get int vector [0, 1, ..., rank-1]
|
|
|
|
std::vector<int64_t> dims(rank);
|
|
|
|
std::iota(dims.begin(), dims.end(), 0);
|
|
|
|
return getDimIndexOfTensor(rewriter, op, value, dims);
|
|
|
|
}
|
|
|
|
|
2022-08-02 09:21:37 +08:00
|
|
|
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%0 = arith.index_cast %dim : index to i64
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%1 = arith.index_cast %dim_0 : index to i64
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%2 = arith.index_cast %dim_1 : index to i64
%from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
%3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %4 : tensor<?x?x?xf32>
}
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
%0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
%1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
}
```
The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
|
|
|
Value tensor,
|
|
|
|
ArrayRef<int64_t> inputUnsqzDims) {
|
2022-08-02 09:21:37 +08:00
|
|
|
// Returns a new tensor with dims of size 1 inserted at the specified
|
|
|
|
// position.
|
|
|
|
//
|
|
|
|
// The position indices (must be high to low dimension number of the returned
|
|
|
|
// tensor) are specified with unsqzDims. Indices must be in-order, and in
|
|
|
|
// range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1,
|
|
|
|
// 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not.
|
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%0 = arith.index_cast %dim : index to i64
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%1 = arith.index_cast %dim_0 : index to i64
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%2 = arith.index_cast %dim_1 : index to i64
%from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
%3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %4 : tensor<?x?x?xf32>
}
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
%0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
%1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
}
```
The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
|
|
|
auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor);
|
2022-08-02 09:21:37 +08:00
|
|
|
if (failed(dimSizesInfo))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
|
|
|
|
auto dimSizes = *dimSizesInfo;
|
2023-12-08 15:13:42 +08:00
|
|
|
int64_t rank = dimSizes.size();
|
|
|
|
int64_t newRank = rank + inputUnsqzDims.size();
|
2022-08-02 09:21:37 +08:00
|
|
|
auto unsqzDims = toPositiveDims(inputUnsqzDims, newRank);
|
2023-12-08 15:13:42 +08:00
|
|
|
for (int64_t k = 0, sz = unsqzDims.size(); k < sz; ++k)
|
2022-08-02 09:21:37 +08:00
|
|
|
if (k > 1 && unsqzDims[k] <= unsqzDims[k - 1])
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "unsqueeze dimensions must be specified in order");
|
|
|
|
|
|
|
|
auto loc = op->getLoc();
|
2024-04-28 05:00:56 +08:00
|
|
|
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
|
2022-08-02 09:21:37 +08:00
|
|
|
auto oldShape = rankTy.getShape();
|
|
|
|
auto one = rewriter.create<arith::ConstantOp>(
|
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%0 = arith.index_cast %dim : index to i64
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%1 = arith.index_cast %dim_0 : index to i64
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%2 = arith.index_cast %dim_1 : index to i64
%from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
%3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %4 : tensor<?x?x?xf32>
}
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
%0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
%1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
}
```
The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
|
|
|
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
|
2022-08-02 09:21:37 +08:00
|
|
|
|
|
|
|
std::vector<Value> newDimSizes;
|
|
|
|
std::vector<int64_t> newShape;
|
|
|
|
newDimSizes.reserve(newRank);
|
|
|
|
newShape.reserve(newRank);
|
2023-12-08 15:13:42 +08:00
|
|
|
for (int64_t k = 0, i = 0, j = 0; k < newRank; ++k) {
|
|
|
|
if (j < static_cast<int64_t>(unsqzDims.size()) && unsqzDims[j] == k) {
|
2022-08-02 09:21:37 +08:00
|
|
|
newDimSizes.push_back(one);
|
|
|
|
newShape.push_back(1);
|
|
|
|
j++;
|
|
|
|
} else {
|
|
|
|
newDimSizes.push_back(dimSizes[i]);
|
|
|
|
newShape.push_back(oldShape[i]);
|
|
|
|
i++;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
|
2023-02-02 21:29:47 +08:00
|
|
|
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
|
|
|
|
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
|
2022-08-02 09:21:37 +08:00
|
|
|
.getResult();
|
|
|
|
}
|
|
|
|
|
2024-04-29 17:40:30 +08:00
|
|
|
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
|
|
|
|
Value tensor, int64_t collapseStartDim,
|
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%0 = arith.index_cast %dim : index to i64
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%1 = arith.index_cast %dim_0 : index to i64
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%2 = arith.index_cast %dim_1 : index to i64
%from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
%3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %4 : tensor<?x?x?xf32>
}
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
%0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
%1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
}
```
The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
|
|
|
int64_t collapseEndDim) {
|
2024-04-29 17:40:30 +08:00
|
|
|
|
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%0 = arith.index_cast %dim : index to i64
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%1 = arith.index_cast %dim_0 : index to i64
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%2 = arith.index_cast %dim_1 : index to i64
%from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
%3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %4 : tensor<?x?x?xf32>
}
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
%0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
%1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
}
```
The benefits of using IndexType on shape tensor:
* simplify the IR, avoiding the need to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
|
|
|
auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor);
|
2024-04-29 17:40:30 +08:00
|
|
|
if (failed(dimSizesInfo))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
|
|
|
|
auto dimSizes = *dimSizesInfo;
|
|
|
|
int64_t rank = dimSizes.size();
|
|
|
|
|
|
|
|
collapseStartDim = toPositiveDim(collapseStartDim, rank);
|
|
|
|
collapseEndDim = toPositiveDim(collapseEndDim, rank);
|
|
|
|
|
|
|
|
int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1);
|
|
|
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
|
|
|
|
auto oldShape = rankTy.getShape();
|
|
|
|
|
|
|
|
std::vector<Value> newDimSizes;
|
|
|
|
std::vector<int64_t> newShape;
|
|
|
|
newDimSizes.reserve(newRank);
|
|
|
|
newShape.reserve(newRank);
|
|
|
|
|
|
|
|
Value collapseDimSize = rewriter.create<arith::ConstantOp>(
|
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%0 = arith.index_cast %dim : index to i64
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%1 = arith.index_cast %dim_0 : index to i64
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%2 = arith.index_cast %dim_1 : index to i64
%from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
%3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %4 : tensor<?x?x?xf32>
}
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
%0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
%1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
}
```
The benefits of using IndexType on shape tensor:
* simplify the IR, avoiding the need to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
|
|
|
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
|
2024-04-29 17:40:30 +08:00
|
|
|
int64_t collapseShape = 1;
|
|
|
|
|
|
|
|
for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
|
|
|
|
if (k < 0 || k >= rank) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "collapse dimensions must be within the rank of the tensor");
|
|
|
|
}
|
|
|
|
if (collapseShape == ShapedType::kDynamic ||
|
|
|
|
oldShape[k] == ShapedType::kDynamic) {
|
|
|
|
collapseShape = ShapedType::kDynamic;
|
|
|
|
} else {
|
|
|
|
collapseShape *= oldShape[k];
|
|
|
|
}
|
|
|
|
collapseDimSize =
|
|
|
|
rewriter.create<arith::MulIOp>(loc, collapseDimSize, dimSizes[k]);
|
|
|
|
}
|
|
|
|
|
|
|
|
for (int64_t k = 0; k < collapseStartDim; ++k) {
|
|
|
|
newDimSizes.push_back(dimSizes[k]);
|
|
|
|
newShape.push_back(oldShape[k]);
|
|
|
|
}
|
|
|
|
newDimSizes.push_back(collapseDimSize);
|
|
|
|
newShape.push_back(collapseShape);
|
|
|
|
for (int64_t k = collapseEndDim + 1; k < rank; ++k) {
|
|
|
|
newDimSizes.push_back(dimSizes[k]);
|
|
|
|
newShape.push_back(oldShape[k]);
|
|
|
|
}
|
|
|
|
|
|
|
|
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
|
|
|
|
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
|
|
|
|
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
|
|
|
|
.getResult();
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO: support splitDim & outerLength to be Value
|
|
|
|
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
|
|
|
|
Value tensor, int64_t splitDim,
|
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%0 = arith.index_cast %dim : index to i64
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%1 = arith.index_cast %dim_0 : index to i64
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%2 = arith.index_cast %dim_1 : index to i64
%from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
%3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %4 : tensor<?x?x?xf32>
}
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
%0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
%1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
}
```
The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
|
|
|
int64_t outerLength) {
|
|
|
|
auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor);
|
2024-04-29 17:40:30 +08:00
|
|
|
if (failed(dimSizesInfo))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
|
|
|
|
auto dimSizes = *dimSizesInfo;
|
|
|
|
int64_t rank = dimSizes.size();
|
|
|
|
splitDim = toPositiveDim(splitDim, rank);
|
|
|
|
|
|
|
|
auto loc = op->getLoc();
|
|
|
|
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
|
|
|
|
auto oldShape = rankTy.getShape();
|
|
|
|
|
|
|
|
if (splitDim < 0 || splitDim >= rank) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "split dimensions must be within the rank of the tensor");
|
|
|
|
}
|
|
|
|
|
|
|
|
int64_t newRank = rank + 1;
|
|
|
|
auto outerLengthValue = rewriter.create<arith::ConstantOp>(
|
[Stablehlo] use index type as dim size, avoid to generate index_cast (#3526)
For example, the original IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%0 = arith.index_cast %dim : index to i64
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%1 = arith.index_cast %dim_0 : index to i64
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%2 = arith.index_cast %dim_1 : index to i64
%from_elements = tensor.from_elements %0, %1, %2 : tensor<3xi64>
%3 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%4 = stablehlo.dot_general %arg0, %3, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %4 : tensor<?x?x?xf32>
}
}
```
After using IndexType, the IR is:
```
module attributes {torch.debug_module_name = "Matmul3D"} {
func.func @forward(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
%from_elements = tensor.from_elements %dim, %dim_0, %dim_1 : tensor<3xindex>
%0 = stablehlo.dynamic_broadcast_in_dim %arg1, %from_elements, dims = [0, 1, 2] : (tensor<?x?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
%1 = stablehlo.dot_general %arg0, %0, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
}
```
The benefits of using IndexType on shape tensor:
* simplify the IR, avoid to generate `arith.index_cast`
* let backend compiler have a chance to decide the index width of shape
tensor
* let stablehlo backend have a chance to serialize dynamic shape IR by
[shape_legalize_to_stablehlo](https://github.com/openxla/stablehlo/blob/main/stablehlo/tests/shape_legalize_to_stablehlo.mlir)
2024-07-07 18:03:03 +08:00
|
|
|
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), outerLength));
|
2024-04-29 17:40:30 +08:00
|
|
|
|
|
|
|
auto innerLengthValue = rewriter.create<arith::DivSIOp>(
|
|
|
|
loc, dimSizes[splitDim], outerLengthValue);
|
|
|
|
|
|
|
|
int64_t originShape = oldShape[splitDim];
|
|
|
|
int64_t outerShape = outerLength;
|
|
|
|
int64_t innerShape = originShape == ShapedType::kDynamic
|
|
|
|
? ShapedType::kDynamic
|
|
|
|
: originShape / outerLength;
|
|
|
|
|
|
|
|
std::vector<Value> newDimSizes;
|
|
|
|
std::vector<int64_t> newShape;
|
|
|
|
|
|
|
|
newDimSizes.reserve(newRank);
|
|
|
|
newShape.reserve(newRank);
|
|
|
|
|
|
|
|
for (int64_t k = 0; k < splitDim; ++k) {
|
|
|
|
newDimSizes.push_back(dimSizes[k]);
|
|
|
|
newShape.push_back(oldShape[k]);
|
|
|
|
}
|
|
|
|
newDimSizes.push_back(outerLengthValue);
|
|
|
|
newShape.push_back(outerShape);
|
|
|
|
newDimSizes.push_back(innerLengthValue);
|
|
|
|
newShape.push_back(innerShape);
|
|
|
|
|
|
|
|
for (int64_t k = splitDim + 1; k < rank; ++k) {
|
|
|
|
newDimSizes.push_back(dimSizes[k]);
|
|
|
|
newShape.push_back(oldShape[k]);
|
|
|
|
}
|
|
|
|
|
|
|
|
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
|
|
|
|
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
|
|
|
|
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
|
|
|
|
.getResult();
|
|
|
|
}
|
|
|
|
|
2022-08-02 12:53:24 +08:00
|
|
|
// Builds a tensor of type `outType` filled with `constant`, broadcast to the
// runtime extents described by `shape` (a 1-D shape tensor).
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                         const APFloat &constant, Value shape,
                         TensorType outType) {
  // First create a 0-d splat constant of the requested element type...
  auto splatAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
  Value splat = rewriter.create<stablehlo::ConstantOp>(loc, splatAttr);
  // ...then stretch it to the dynamic output shape. An empty
  // broadcast_dimensions list means every result dimension is expanded.
  auto bcast = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(
      loc, outType, splat, shape, rewriter.getDenseI64ArrayAttr({}));
  return bcast.getResult();
}
|
2023-02-02 21:29:47 +08:00
|
|
|
} // namespace hlo
|
2022-08-03 08:16:31 +08:00
|
|
|
} // namespace mlir
|