2022-03-11 01:54:13 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2024-01-30 01:59:33 +08:00
|
|
|
#include "mlir/Dialect/Tensor/Utils/Utils.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "../PassDetail.h"
|
|
|
|
#include "PopulatePatterns.h"
|
2022-10-05 21:28:06 +08:00
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
2023-12-02 08:38:21 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
2022-03-11 01:54:13 +08:00
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
|
|
|
|
|
|
|
// Converts each integer in `ints` into an index attribute created by `b` and
// returns the attributes, in order, as a list of OpFoldResults.
static SmallVector<OpFoldResult>
getIndexIntsAsOpFoldResult(OpBuilder &b, SmallVectorImpl<int64_t> &ints) {
  SmallVector<OpFoldResult> indexAttrs;
  indexAttrs.reserve(ints.size());
  for (int64_t value : ints) {
    OpFoldResult asAttr = b.getIndexAttr(value);
    indexAttrs.push_back(asAttr);
  }
  return indexAttrs;
}
|
|
|
|
|
|
|
|
// Helper function to get the padding tensor given the padding int values.
|
|
|
|
Value torch_to_linalg::getPaddedTensor(
|
|
|
|
Operation *op, OpBuilder &b, Value &input,
|
|
|
|
SmallVectorImpl<int64_t> &lowPaddingInts,
|
|
|
|
SmallVectorImpl<int64_t> &highPaddingInts, Value pad) {
|
|
|
|
Location loc = op->getLoc();
|
2024-04-28 05:00:56 +08:00
|
|
|
Type rankedTensorType = tensor::PadOp::inferResultType(
|
|
|
|
cast<RankedTensorType>(input.getType()), lowPaddingInts, highPaddingInts);
|
2022-03-11 01:54:13 +08:00
|
|
|
SmallVector<OpFoldResult> lowPaddings =
|
|
|
|
getIndexIntsAsOpFoldResult(b, lowPaddingInts);
|
|
|
|
SmallVector<OpFoldResult> highPaddings =
|
|
|
|
getIndexIntsAsOpFoldResult(b, highPaddingInts);
|
2022-11-01 15:27:09 +08:00
|
|
|
Value paddedInput =
|
|
|
|
b.create<tensor::PadOp>(loc, rankedTensorType, input, /*low=*/lowPaddings,
|
|
|
|
/*high=*/highPaddings, pad);
|
2022-03-11 01:54:13 +08:00
|
|
|
return paddedInput;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Helper function to get the padding tensor given the padding int values.
|
|
|
|
// It's assumed that the padding on the low end and high end are the same,
|
|
|
|
// and that zero padding is required.
|
2022-04-01 16:23:29 +08:00
|
|
|
Value torch_to_linalg::getZeroPaddedTensor(
|
|
|
|
Operation *op, OpBuilder &b, Value &input,
|
|
|
|
SmallVectorImpl<int64_t> &paddingInts) {
|
2022-03-11 01:54:13 +08:00
|
|
|
assert(input.getType().isa<RankedTensorType>() &&
|
|
|
|
"input must be RankedTensorType");
|
|
|
|
Location loc = op->getLoc();
|
|
|
|
Value c0 = b.create<arith::ConstantOp>(
|
|
|
|
loc,
|
2024-04-28 05:00:56 +08:00
|
|
|
b.getZeroAttr(cast<RankedTensorType>(input.getType()).getElementType()));
|
2022-03-11 01:54:13 +08:00
|
|
|
return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
|
|
|
|
}
|
|
|
|
|
2022-11-04 15:57:29 +08:00
|
|
|
// Helper function that adds dynamic padding to a tensor, ignoring unpaddedDims
|
|
|
|
// dimensions at the beginning. The high and low padding are the same, and the
|
|
|
|
// padding value is zero.
|
|
|
|
Value torch_to_linalg::getDynamicZeroPaddedTensor(
|
|
|
|
Operation *op, OpBuilder &b, Value &input, SmallVectorImpl<Value> &padding,
|
2024-01-31 05:46:47 +08:00
|
|
|
int unpaddedDims, Value pad) {
|
2022-11-04 15:57:29 +08:00
|
|
|
assert(input.getType().isa<RankedTensorType>() &&
|
|
|
|
"input must be RankedTensorType");
|
2024-04-28 05:00:56 +08:00
|
|
|
unsigned int inRank = cast<RankedTensorType>(input.getType()).getRank();
|
2022-11-04 15:57:29 +08:00
|
|
|
Location loc = op->getLoc();
|
|
|
|
|
|
|
|
SmallVector<Value> inputDims = getTensorSizes(b, loc, input);
|
|
|
|
Value c0 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
|
|
|
|
SmallVector<Value> paddingIncludingUnchanged(unpaddedDims, c0);
|
|
|
|
paddingIncludingUnchanged.append(padding);
|
|
|
|
assert(unpaddedDims + padding.size() == inRank &&
|
|
|
|
"sum of unpaddedDims and padding.size() must equal to inputRank");
|
|
|
|
for (auto pad = paddingIncludingUnchanged.begin();
|
|
|
|
pad < paddingIncludingUnchanged.end(); pad++)
|
|
|
|
*pad = castIntToIndex(b, loc, *pad);
|
|
|
|
|
2024-04-28 05:00:56 +08:00
|
|
|
Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
|
2024-01-27 02:54:59 +08:00
|
|
|
// TODO: audit possibility of sparsity on this tensor
|
2022-11-29 20:33:31 +08:00
|
|
|
Type inputType =
|
|
|
|
RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
|
|
|
|
SmallVector<int64_t>(inRank, kUnknownSize))),
|
|
|
|
elementType);
|
2022-11-04 15:57:29 +08:00
|
|
|
|
|
|
|
SmallVector<OpFoldResult> paddingValues =
|
|
|
|
getAsOpFoldResult(paddingIncludingUnchanged);
|
|
|
|
return b.create<tensor::PadOp>(loc, inputType, input, /*low=*/paddingValues,
|
2024-01-31 05:46:47 +08:00
|
|
|
/*high=*/paddingValues, pad);
|
2022-11-04 15:57:29 +08:00
|
|
|
}
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
// Emits arith ops computing, as an index Value,
//   div(in + 2 * padding - dilation * (kernelSize - 1) - 1, stride) + 1
// where `div` is a signed ceiling division when `ceilMode` is true and a
// signed floor division otherwise. `in` is converted with castIndexToInt64
// before use; the remaining operands are used directly in integer arithmetic
// with i64 constants.
Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
                                              Value in, Value paddingInt,
                                              Value dilationInt,
                                              Value kernelSizeInt,
                                              Value strideInt, bool ceilMode) {
  Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
  Value two = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));

  // in + 2 * padding
  Value twicePadding = b.create<arith::MulIOp>(loc, paddingInt, two);
  Value inAsInt = castIndexToInt64(b, loc, in);
  Value paddedExtent = b.create<arith::AddIOp>(loc, inAsInt, twicePadding);

  // dilation * (kernelSize - 1)
  Value kernelMinusOne = b.create<arith::SubIOp>(loc, kernelSizeInt, one);
  Value dilatedKernel =
      b.create<arith::MulIOp>(loc, dilationInt, kernelMinusOne);

  // in + 2 * padding - dilation * (kernelSize - 1) - 1
  Value remaining = b.create<arith::SubIOp>(loc, paddedExtent, dilatedKernel);
  Value numerator = b.create<arith::SubIOp>(loc, remaining, one);

  // Pick the rounding direction of the division by the stride.
  auto divideByStride = [&](Value lhs) -> Value {
    if (ceilMode)
      return b.create<arith::CeilDivSIOp>(loc, lhs, strideInt);
    return b.create<arith::FloorDivSIOp>(loc, lhs, strideInt);
  };
  Value quotient = divideByStride(numerator);

  Value outAsInt = b.create<arith::AddIOp>(loc, quotient, one);
  return castIntToIndex(b, loc, outAsInt);
}
|
|
|
|
|
2022-08-25 00:19:35 +08:00
|
|
|
// Emits arith ops computing, as an index Value,
//   (in - 1) * stride - 2 * padding + (kernelSize - 1) * dilation
//     + outputPadding + 1
// `in` is converted with castIndexToInt64 before use; the remaining operands
// are used directly in integer arithmetic with i64 constants.
Value torch_to_linalg::getOutputDimForConvTransposeOps(
    OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt,
    Value kernelSizeInt, Value strideInt, Value outputPaddingInt) {
  Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
  Value two = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));

  // (in - 1) * stride
  Value inAsInt = castIndexToInt64(b, loc, in);
  Value inMinusOne = b.create<arith::SubIOp>(loc, inAsInt, one);
  Value stridedExtent = b.create<arith::MulIOp>(loc, inMinusOne, strideInt);

  // 2 * padding
  Value twicePadding = b.create<arith::MulIOp>(loc, paddingInt, two);

  // (kernelSize - 1) * dilation
  Value kernelMinusOne = b.create<arith::SubIOp>(loc, kernelSizeInt, one);
  Value dilatedKernel =
      b.create<arith::MulIOp>(loc, kernelMinusOne, dilationInt);

  // Combine the terms in the order: - padding, + kernel, + output padding, + 1.
  Value withoutPadding =
      b.create<arith::SubIOp>(loc, stridedExtent, twicePadding);
  Value withKernel =
      b.create<arith::AddIOp>(loc, withoutPadding, dilatedKernel);
  Value withOutputPadding =
      b.create<arith::AddIOp>(loc, withKernel, outputPaddingInt);
  Value outAsInt = b.create<arith::AddIOp>(loc, withOutputPadding, one);

  return castIntToIndex(b, loc, outAsInt);
}
|
|
|
|
|
2022-03-11 01:54:13 +08:00
|
|
|
Value torch_to_linalg::createReductionLinalgGeneric(
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
|
2022-03-11 01:54:13 +08:00
|
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
2024-04-28 05:00:56 +08:00
|
|
|
auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
|
2022-03-11 01:54:13 +08:00
|
|
|
|
|
|
|
// Get the result shape by obtaining the size of each
|
|
|
|
// dimension in the input tensor that is not getting reduced.
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
// If `opInfo.keepDim` is true, the rank of the output tensor
|
2022-03-11 01:54:13 +08:00
|
|
|
// is kept the same as the rank of the input tensor, and the
|
|
|
|
// reduced dimensions are set to have size 1.
|
|
|
|
auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
|
|
|
|
SmallVector<Value> resultShape;
|
|
|
|
for (int64_t i = 0; i < inputType.getRank(); i++) {
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
auto currentDimSize = b.create<tensor::DimOp>(loc, opInfo.tensorOperand, i);
|
|
|
|
if (!opInfo.dimSet.contains(i))
|
2022-03-11 01:54:13 +08:00
|
|
|
resultShape.push_back(currentDimSize);
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
else if (opInfo.keepDim)
|
2022-03-11 01:54:13 +08:00
|
|
|
resultShape.push_back(c1);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create the affine expressions that will be used to
|
|
|
|
// iterate over the input and output tensors.
|
|
|
|
// Here we also set the type of iterator: parallel or reduction.
|
|
|
|
SmallVector<AffineExpr> exprs;
|
2022-11-17 06:40:36 +08:00
|
|
|
SmallVector<utils::IteratorType> iteratorTypes;
|
2022-03-11 01:54:13 +08:00
|
|
|
SmallVector<AffineExpr> resultExprs;
|
2022-11-29 20:33:31 +08:00
|
|
|
for (auto size :
|
|
|
|
llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) {
|
2022-03-11 01:54:13 +08:00
|
|
|
exprs.push_back(b.getAffineDimExpr(size.index()));
|
|
|
|
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
if (opInfo.dimSet.contains(size.index())) {
|
2022-11-17 06:40:36 +08:00
|
|
|
iteratorTypes.push_back(utils::IteratorType::reduction);
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
// If `opInfo.keepDim`, create affine map to the first element
|
2022-03-11 01:54:13 +08:00
|
|
|
// in the current dimension.
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
if (opInfo.keepDim)
|
2022-03-11 01:54:13 +08:00
|
|
|
resultExprs.push_back(b.getAffineConstantExpr(0));
|
|
|
|
} else {
|
2022-11-17 06:40:36 +08:00
|
|
|
iteratorTypes.push_back(utils::IteratorType::parallel);
|
2022-03-11 01:54:13 +08:00
|
|
|
resultExprs.push_back(b.getAffineDimExpr(size.index()));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-02-10 06:07:49 +08:00
|
|
|
auto indexingMaps =
|
|
|
|
AffineMap::inferFromExprList({exprs, resultExprs}, b.getContext());
|
2022-03-11 01:54:13 +08:00
|
|
|
Value accumulator =
|
|
|
|
createInitTensor(b, loc, resultShape, initElem.getType(), initElem);
|
|
|
|
|
|
|
|
return b
|
|
|
|
.create<linalg::GenericOp>(
|
|
|
|
loc, /*resultTensorTypes=*/accumulator.getType(),
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
/*inputs=*/opInfo.tensorOperand,
|
2022-03-11 01:54:13 +08:00
|
|
|
/*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild)
|
|
|
|
.getResult(0);
|
|
|
|
}
|
torch,linalg: add support for translating aten.linalg.vector_norm (#839)
This patch adds support for the torch.linalg.vector_norm op to the torch
dialect, including the necessary shape function. It also extends the
conversion of reduction operators to support lowering of
AtenLinalgVectorNormOp, in addition to adding a handful of end-to-end
tests to validate the lowering.
There exist several opportunities to make this lowering optimal and
robust. For instance, in its current form, the translation does not
support ord = 0, +inf, or -inf. For L1 norms, we don't need to raise
each element to the power 1.0. Similarly, L2 norms could benefit from
strength reduction. Since the canonicalization pass is not able to
apply these optimizations, we should consider applying them during the
linalg lowering itself.
2022-05-20 06:48:15 +08:00
|
|
|
|
|
|
|
// Builds an all-parallel linalg.generic over `tensorOperands`, whose result
// has element type `resultElementType` and rank equal to the largest operand
// rank. Operands of lower rank are aligned to the trailing result dimensions,
// and dimensions of static size 1 are indexed with a constant 0. `bodyBuild`
// populates the body of the generic op. Every operand must be a ranked tensor.
// Returns the single result of the generic op.
Value torch_to_linalg::createElementwiseLinalgGeneric(
    OpBuilder &b, Location loc, ValueRange tensorOperands,
    Type resultElementType,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  // The overall error handling strategy here is best viewed by thinking about
  // what happens for a single result dimension. This loop is not structured
  // that way because it is hard to create the affine maps for each operand
  // unless we structure the loop to iterate over tensor operands as the outer
  // loop instead of inner loop. This pseudocode gives better intuition:
  // ```
  // for each result dimension:
  //   for each tensor operand:
  //     if it doesn't even have high enough rank relative to the result:
  //       continue
  //     if it is a static size-1 along this result dimension:
  //       continue
  //     if this is the first tensor operand that didn't continue above:
  //       take its dimension size as the size of the non-broadcasted
  //       traversal along this dimension (this may include a dynamic size-1,
  //       **non-broadcasted** traversal unless
  //       isAssumingStrictSymbolicShapes!)
  //     emit error check "if the size does not match the non-broadcasted
  //     traversal size along this dimension, error"
  // ```
  SmallVector<int64_t> operandRanks;
  operandRanks.resize(tensorOperands.size());
  llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
    // cast<> (not dyn_cast<>): the result is used unconditionally, so a
    // non-ranked operand must trip the cast assertion rather than yield a
    // null type that is then dereferenced.
    return cast<RankedTensorType>(tensor.getType()).getRank();
  });

  auto resultRankIt =
      std::max_element(operandRanks.begin(), operandRanks.end());
  assert(resultRankIt != operandRanks.end() && "Unable to get result rank.");
  int64_t resultRank = *resultRankIt;

  // Initialize the resultShape to all 1's, as a fallback in case
  // all sizes along that result dimension are statically 1.
  auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
  SmallVector<Value> resultShape(resultRank, c1);
  SmallVector<AffineMap> indexingMaps;
  bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b);

  for (Value tensorOperand : tensorOperands) {
    SmallVector<AffineExpr> exprs;
    auto type = cast<RankedTensorType>(tensorOperand.getType());
    for (auto size :
         llvm::enumerate(makeShapeTorchCompatible(type.getShape()))) {
      // If the size is statically known to be 1, we don't want any
      // error guards to be spuriously emitted, since we are specifically
      // allowing size-1 broadcasts in this case, as they correspond to a
      // constant-0 indexing map.
      if (size.value() == 1) {
        exprs.push_back(b.getAffineConstantExpr(0));
        continue;
      }

      // The rank of this operand might be smaller than the overall rank of
      // the broadcast. Add an offset to correlate it to the correct
      // dimension of the result.
      auto resultDim = size.index() + (resultRank - type.getRank());

      // The generated linalg op will now be iterating along the full size
      // of this dimension. Record that fact.
      exprs.push_back(b.getAffineDimExpr(resultDim));

      // Now, we need to ensure that such iteration is not going to trigger
      // undefined behavior, by doing appropriate checks against the current
      // dimension size.
      auto currentDimSize = getDimOp(b, loc, tensorOperand, size.index());

      // If the result size of this dimension has so far only hit the
      // statically-known-to-be-1 case above (i.e., we have not yet assigned a
      // new Value to `resultShape[resultDim]`), then we have no other dynamic
      // values to check against, and merely need to record the current
      // dimension size.
      if (resultShape[resultDim] == c1) {
        resultShape[resultDim] = currentDimSize;
        continue;
      }

      // We prohibit the size-1 dynamic broadcasting scenario, so just check
      // for exact equality with the running result size.
      // This is the check which protects against the undefined behavior of
      // the generated linalg op in the case of iterating two operands with
      // dimensions sizes that are expected to match.
      if (!elideDynamicBroadcastCheck) {
        auto equalToRunning =
            b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                    resultShape[resultDim], currentDimSize);
        b.create<cf::AssertOp>(loc, equalToRunning,
                               "mismatched size for broadcast");
      }
    }
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
  }

  SmallVector<utils::IteratorType> iteratorTypes(resultRank,
                                                 utils::IteratorType::parallel);
  // Add the indexing map for the outs init tensor.
  indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));

  Value initTensor = b.create<tensor::EmptyOp>(
      loc, getAsOpFoldResult(resultShape), resultElementType);
  return b
      .create<linalg::GenericOp>(loc,
                                 /*resultTensorTypes=*/initTensor.getType(),
                                 /*inputs=*/tensorOperands,
                                 /*outputs=*/initTensor, indexingMaps,
                                 iteratorTypes, bodyBuild)
      .getResult(0);
}
|
2022-06-16 23:45:10 +08:00
|
|
|
|
|
|
|
// Broadcasts input tensor based on the broadcastToShape.
|
|
|
|
LogicalResult torch_to_linalg::broadcastToGivenShape(
|
|
|
|
Operation *op, PatternRewriter &rewriter, Value input,
|
2023-10-06 03:15:26 +08:00
|
|
|
SmallVector<Value> broadcastToShape, RankedTensorType broadcastType,
|
|
|
|
Value &result, SmallVector<bool> useBroadcastToShape) {
|
2024-04-28 05:00:56 +08:00
|
|
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
2023-10-06 03:15:26 +08:00
|
|
|
int64_t inputRank = inputType.getRank();
|
|
|
|
int64_t outputRank = broadcastToShape.size();
|
|
|
|
ArrayRef<int64_t> outputShape = broadcastType.getShape();
|
2022-11-29 20:33:31 +08:00
|
|
|
SmallVector<int64_t> inputShape =
|
|
|
|
makeShapeTorchCompatible(inputType.getShape());
|
2023-10-06 03:15:26 +08:00
|
|
|
if (outputRank < inputRank) {
|
2022-06-16 23:45:10 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "invalid shape: broadcastToShape size must not be smaller than the "
|
|
|
|
"size of the input shape");
|
|
|
|
}
|
|
|
|
|
|
|
|
Type elementType = inputType.getElementType();
|
|
|
|
Location loc = op->getLoc();
|
2023-10-06 03:15:26 +08:00
|
|
|
SmallVector<OpFoldResult> outShape;
|
2023-09-30 07:45:48 +08:00
|
|
|
bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(rewriter);
|
2022-06-16 23:45:10 +08:00
|
|
|
|
2023-10-06 03:15:26 +08:00
|
|
|
// Vector indicating broadcasted status when assuming strict symbolic shapes.
|
|
|
|
SmallVector<bool> broadcastedStatus;
|
|
|
|
|
2022-06-16 23:45:10 +08:00
|
|
|
// Create affine map and shapes for tensor initialization.
|
|
|
|
SmallVector<AffineExpr> outExpr;
|
|
|
|
Value zero =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
|
2023-07-08 01:01:51 +08:00
|
|
|
Value zeroIndex =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
|
|
|
Value oneIndex =
|
|
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
2023-10-06 03:15:26 +08:00
|
|
|
size_t diff = outputRank - inputRank;
|
|
|
|
bool hasDynamicNumpyBroadcast = false;
|
|
|
|
for (size_t i = 0, e = outputRank; i < e; i++) {
|
2022-06-16 23:45:10 +08:00
|
|
|
Value shapeValue = broadcastToShape[i];
|
|
|
|
size_t j = i - diff;
|
2023-10-06 03:15:26 +08:00
|
|
|
bool isDynamic = i >= diff && inputShape[j] == kUnknownSize;
|
|
|
|
|
|
|
|
// Inherit static output shapes if present.
|
|
|
|
if (outputShape[i] != ShapedType::kDynamic) {
|
|
|
|
outShape.push_back(rewriter.getIndexAttr(outputShape[i]));
|
|
|
|
if (i < diff) {
|
|
|
|
if (outputShape[i] < 0) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "invalid shape: negative values not allowed in new broadcast "
|
|
|
|
"dimensions");
|
|
|
|
}
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
if (isDynamic) {
|
|
|
|
hasDynamicNumpyBroadcast = true;
|
|
|
|
} else if (inputShape[j] != outputShape[i] && inputShape[j] != 1) {
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "invalid shape: static mismatch in input and output broadcast "
|
|
|
|
"shapes");
|
|
|
|
}
|
|
|
|
|
|
|
|
// If strict symbolic shapes are assumed and the input shape is dynamic,
|
|
|
|
// we can assume that dim is not broadcasted.
|
|
|
|
broadcastedStatus.push_back(inputShape[j] != outputShape[i] &&
|
|
|
|
!isDynamic);
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
2022-06-16 23:45:10 +08:00
|
|
|
if (i < diff) {
|
2023-09-30 07:45:48 +08:00
|
|
|
if (!elideDynamicBroadcastCheck) {
|
|
|
|
Value isValid = rewriter.create<arith::CmpIOp>(
|
|
|
|
loc, arith::CmpIPredicate::sge, shapeValue, zero);
|
|
|
|
rewriter.create<cf::AssertOp>(
|
|
|
|
loc, isValid,
|
|
|
|
rewriter.getStringAttr(
|
|
|
|
"negative values not allowed in new dimensions"));
|
|
|
|
}
|
2022-06-16 23:45:10 +08:00
|
|
|
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
if (inputShape[j] == 1) {
|
|
|
|
// Broadcast singleton dimension
|
|
|
|
Value isNegative = rewriter.create<arith::CmpIOp>(
|
|
|
|
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
|
|
|
Value select = rewriter.create<arith::SelectOp>(
|
2023-07-08 01:01:51 +08:00
|
|
|
loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue));
|
2022-06-16 23:45:10 +08:00
|
|
|
outShape.push_back(select);
|
2023-10-06 03:15:26 +08:00
|
|
|
broadcastedStatus.push_back(true);
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Case of dynamic input dimension wherein the shape to broadcast will
|
|
|
|
// yield us the dimension size of the output.
|
|
|
|
Value dim;
|
|
|
|
if (!useBroadcastToShape.empty() && useBroadcastToShape[j]) {
|
|
|
|
dim = castIntToIndex(rewriter, loc, broadcastToShape[i]);
|
|
|
|
if (isDynamic) {
|
|
|
|
hasDynamicNumpyBroadcast = true;
|
2023-07-08 01:01:51 +08:00
|
|
|
}
|
2023-10-06 03:15:26 +08:00
|
|
|
if (!elideDynamicBroadcastCheck) {
|
|
|
|
Value isValid = rewriter.create<arith::CmpIOp>(
|
|
|
|
loc, arith::CmpIPredicate::sge, shapeValue, zero);
|
|
|
|
rewriter.create<cf::AssertOp>(
|
|
|
|
loc, isValid,
|
|
|
|
rewriter.getStringAttr(
|
|
|
|
"unimplemented: dynamic negative broadcast sizes"));
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
dim = getDimOp(rewriter, loc, input, j);
|
2022-06-16 23:45:10 +08:00
|
|
|
}
|
2023-10-06 03:15:26 +08:00
|
|
|
// We can safely assume this dimension is not broadcasted with strict
|
|
|
|
// symbols.
|
|
|
|
broadcastedStatus.push_back(false);
|
|
|
|
outShape.push_back(dim);
|
2022-06-16 23:45:10 +08:00
|
|
|
}
|
|
|
|
|
2023-10-06 03:15:26 +08:00
|
|
|
Value outTensor =
|
|
|
|
rewriter.create<tensor::EmptyOp>(loc, outShape, elementType);
|
|
|
|
|
|
|
|
// If we know there are no ? -> ? broadcasted dims, or we are assuming
|
|
|
|
// strict symbols, we can safely use standard linalg style broadcasting
|
|
|
|
// semantics.
|
|
|
|
if (!hasDynamicNumpyBroadcast || elideDynamicBroadcastCheck) {
|
|
|
|
// If no dims are broadcasted and the rank doesn't change, we can just fold
|
|
|
|
// the op away entirely.
|
|
|
|
if (!llvm::any_of(broadcastedStatus, [](bool b) { return b; }) &&
|
|
|
|
inputRank == outputRank) {
|
|
|
|
result = rewriter.create<tensor::CastOp>(loc, outTensor.getType(), input);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<AffineExpr> inputExprs;
|
|
|
|
for (int64_t i = 0, e = inputRank; i < e; ++i) {
|
|
|
|
if (broadcastedStatus[i]) {
|
|
|
|
inputExprs.push_back(rewriter.getAffineConstantExpr(0));
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
inputExprs.push_back(rewriter.getAffineDimExpr(i + diff));
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<AffineMap> indexingMaps = {
|
|
|
|
AffineMap::get(outputRank, 0, inputExprs, rewriter.getContext()),
|
|
|
|
rewriter.getMultiDimIdentityMap(outputRank)};
|
|
|
|
SmallVector<utils::IteratorType> iteratorTypes(
|
|
|
|
outputRank, utils::IteratorType::parallel);
|
|
|
|
result = rewriter
|
|
|
|
.create<linalg::GenericOp>(
|
|
|
|
loc, outTensor.getType(), input, outTensor, indexingMaps,
|
|
|
|
iteratorTypes,
|
|
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
|
|
b.create<linalg::YieldOp>(loc, args[0]);
|
|
|
|
})
|
|
|
|
.getResult(0);
|
|
|
|
return success();
|
|
|
|
}
|
2022-06-16 23:45:10 +08:00
|
|
|
|
2023-10-06 03:15:26 +08:00
|
|
|
// Fall back to numpy-style dynamic broadcasting in the form of a single
|
|
|
|
// linalg op.
|
2022-06-16 23:45:10 +08:00
|
|
|
SmallVector<AffineMap> indexingMaps = {
|
2023-10-06 03:15:26 +08:00
|
|
|
rewriter.getMultiDimIdentityMap(outputRank)};
|
|
|
|
SmallVector<utils::IteratorType> iteratorTypes(outputRank,
|
2022-11-17 06:40:36 +08:00
|
|
|
utils::IteratorType::parallel);
|
2022-06-16 23:45:10 +08:00
|
|
|
result = rewriter
|
|
|
|
.create<linalg::GenericOp>(
|
2023-07-08 01:01:51 +08:00
|
|
|
loc, outTensor.getType(), ValueRange(), outTensor,
|
|
|
|
indexingMaps, iteratorTypes,
|
|
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
|
|
// `loopIndices` contains IV of the linalg loops which
|
|
|
|
// would be used to extract values from the input tensor
|
|
|
|
// later on.
|
|
|
|
SmallVector<Value> loopIndices;
|
2023-10-06 03:15:26 +08:00
|
|
|
for (size_t i = 0, e = outputRank; i < e; ++i) {
|
2023-07-08 01:01:51 +08:00
|
|
|
if (i < diff)
|
|
|
|
continue;
|
|
|
|
loopIndices.push_back(b.create<linalg::IndexOp>(loc, i));
|
|
|
|
}
|
|
|
|
// `inputIndicesToExtract` contains i-th linalg loop IV if
|
|
|
|
// the i-th input dimension is not 1, else it contains a
|
|
|
|
// zero index.
|
|
|
|
SmallVector<Value> inputIndicesToExtract;
|
2023-10-06 03:15:26 +08:00
|
|
|
for (size_t i = 0, n = inputRank; i < n; i++) {
|
2023-07-08 01:01:51 +08:00
|
|
|
if (inputShape[i] == 1) {
|
|
|
|
inputIndicesToExtract.push_back(zeroIndex);
|
|
|
|
} else {
|
|
|
|
Value inputDim = getDimOp(b, loc, input, i);
|
|
|
|
Value isEqual = b.create<arith::CmpIOp>(
|
|
|
|
loc, arith::CmpIPredicate::eq, inputDim, oneIndex);
|
|
|
|
Value select = rewriter.create<arith::SelectOp>(
|
|
|
|
loc, isEqual, zeroIndex, loopIndices[i]);
|
|
|
|
inputIndicesToExtract.push_back(select);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Extract and yield the value from input tensor at
|
|
|
|
// `inputIndicesToExtract` indices.
|
|
|
|
Value result = b.create<tensor::ExtractOp>(
|
|
|
|
loc, input, inputIndicesToExtract);
|
|
|
|
b.create<linalg::YieldOp>(loc, result);
|
2022-06-16 23:45:10 +08:00
|
|
|
})
|
|
|
|
.getResult(0);
|
|
|
|
|
|
|
|
return success();
|
|
|
|
}
|
2022-08-25 00:19:35 +08:00
|
|
|
|
|
|
|
// Casts `tensor` to a ranked tensor type of the same rank in which every
// dimension size is marked as unknown, and returns the result of that
// `tensor.cast`.
Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
                                             Value tensor) {
  auto rankedType = cast<RankedTensorType>(tensor.getType());
  // One `kUnknownSize` entry per dimension of the original type.
  SmallVector<int64_t> erasedShape(rankedType.getRank(), kUnknownSize);
  auto erasedType = rankedType.clone(makeShapeLLVMCompatible(erasedShape));
  return b.create<tensor::CastOp>(loc, erasedType, tensor);
}
|
2023-09-11 20:58:59 +08:00
|
|
|
|
|
|
|
// Builds an elementwise linalg generic over `tensor` whose payload converts
// each scalar to `elementType` via `convertScalarToDtype`, and returns the
// value produced by `createElementwiseLinalgGeneric`.
Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc,
                                                  Value tensor,
                                                  Type elementType) {
  return torch_to_linalg::createElementwiseLinalgGeneric(
      b, loc, {tensor}, elementType,
      [&](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange blockArgs) {
        // The single payload argument is the current input element.
        Value converted = convertScalarToDtype(bodyBuilder, bodyLoc,
                                               blockArgs[0], elementType);
        bodyBuilder.create<linalg::YieldOp>(bodyLoc, converted);
      });
}
|
2023-11-30 01:43:09 +08:00
|
|
|
|
|
|
|
// Maps a torch scalar type to the MLIR type used by the linalg-on-tensors
// backend. Returns failure when `getTypeForScalarType` cannot map `dtypeInt`.
// Integer types are returned as signless integers of the same width; every
// other type is returned unchanged.
FailureOr<Type> torch_to_linalg::getBackendTypeForScalarType(
    MLIRContext *context, torch_upstream::ScalarType dtypeInt) {
  // `dtypeInt` already has type `torch_upstream::ScalarType`, so no cast is
  // needed here.
  FailureOr<Type> maybeType = getTypeForScalarType(context, dtypeInt);
  if (failed(maybeType)) {
    return failure();
  }
  Type type = *maybeType;
  // The linalg-on-tensors backend currently expects integers to be signless.
  if (auto intType = dyn_cast<IntegerType>(type)) {
    type = IntegerType::get(context, intType.getWidth(), IntegerType::Signless);
  }
  return type;
}
|
2024-01-13 11:11:14 +08:00
|
|
|
|
|
|
|
// Reports whether `type` is treated as unsigned. Value tensor types are
// classified by their dtype. Any type not handled below hits
// `llvm_unreachable`.
bool torch_to_linalg::isUnsignedTorchType(Type type) {
  if (auto tensorTy = dyn_cast<ValueTensorType>(type))
    return isUnsignedTorchType(tensorTy.getDtype());

  // The only quantized type handled here that counts as unsigned.
  if (isa<QUInt8Type>(type))
    return true;

  // Floats and the remaining quantized types are never unsigned.
  if (isa<mlir::FloatType, QInt8Type, QInt32Type>(type))
    return false;

  // Builtin integers carry their own signedness.
  if (auto integerTy = dyn_cast<IntegerType>(type))
    return integerTy.isUnsigned();

  llvm_unreachable("Unknown type checked for signedness");
  return false;
}
|