2022-03-03 00:42:25 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
|
2022-10-05 21:28:06 +08:00
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2024-05-01 14:36:53 +08:00
|
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
2022-03-03 00:42:25 +08:00
|
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
2022-04-27 03:27:51 +08:00
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
2022-03-03 00:42:25 +08:00
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
2022-12-22 03:04:07 +08:00
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
2022-03-03 00:42:25 +08:00
|
|
|
|
|
|
|
namespace mlir {
|
|
|
|
namespace torch {
|
|
|
|
namespace Torch {
|
|
|
|
|
|
|
|
/// Checks that every operand type and result type of `op` can be handled by
/// the Linalg lowering, reporting a match failure through `rewriter` if not.
///
/// A type is rejected when it is a `NonValueTensorType`, or when it is a
/// `ValueTensorType` whose builtin tensor form is null or is not a
/// `RankedTensorType`. Any other type is accepted.
LogicalResult verifyLinalgCompatibleTypes(Operation *op,
                                          PatternRewriter &rewriter) {
  // TODO: Remove this check but use a separate verification pass to verify the
  // invariants expected by later passes.
  auto isLinalgCompatible = [](Type type) -> bool {
    if (isa<NonValueTensorType>(type))
      return false;
    auto valueTensor = dyn_cast<ValueTensorType>(type);
    if (!valueTensor)
      return true;
    // Value tensors must be ranked, as expected by Linalg.
    auto builtinTensor = valueTensor.toBuiltinTensor();
    return builtinTensor && isa<RankedTensorType>(builtinTensor);
  };

  if (!llvm::all_of(op->getOperandTypes(), isLinalgCompatible) ||
      !llvm::all_of(op->getResultTypes(), isLinalgCompatible))
    return rewriter.notifyMatchFailure(op, "type cannot be lowered to linalg");
  return success();
}
|
|
|
|
|
|
|
|
/// Succeeds unless the type of `v` is `OptionalType`, `Torch::NoneType` or the
/// builtin `mlir::NoneType`; in those cases a match failure is reported on
/// `op` through `rewriter`.
LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
  const bool isNoneLike =
      isa<OptionalType, Torch::NoneType, mlir::NoneType>(v.getType());
  if (!isNoneLike)
    return success();
  return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
}
|
|
|
|
|
|
|
|
// Generate IR: dim = dim >= 0 ? dim : dim + inputRank
|
|
|
|
Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
|
|
|
|
Value inputRank) {
|
2024-04-28 05:00:56 +08:00
|
|
|
assert(isa<IntegerType>(dim.getType()) &&
|
2022-03-03 00:42:25 +08:00
|
|
|
"dim arg of toPositiveDim must be integer type");
|
|
|
|
Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
|
|
|
|
Value cst0 =
|
|
|
|
b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
|
|
|
|
Value predDimGEZero =
|
|
|
|
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
|
|
|
|
Value dimInt =
|
|
|
|
b.create<arith::SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
|
|
|
|
return dimInt;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Generate IR: assert(dim >= 0 && dim < inputRank)
|
|
|
|
void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank) {
|
2024-05-31 14:45:13 +08:00
|
|
|
assert(isa<IntegerType>(dim.getType()) &&
|
2022-03-03 00:42:25 +08:00
|
|
|
"dim arg of assertIsValidDim must be integer type");
|
|
|
|
Value cst0 =
|
|
|
|
b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
|
|
|
|
Value predGEZero =
|
|
|
|
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
|
|
|
|
b.create<cf::AssertOp>(
|
|
|
|
loc, predGEZero, b.getStringAttr("dim must be greater or equal to zero"));
|
|
|
|
Value predLTInputRank =
|
|
|
|
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, dim, inputRank);
|
|
|
|
b.create<cf::AssertOp>(loc, predLTInputRank,
|
|
|
|
b.getStringAttr("dim must be smaller than inputRank"));
|
|
|
|
}
|
|
|
|
|
|
|
|
// Hack to deal with the Torch list type arguments which is not supported end
|
|
|
|
// to end. Constant values can be be extracted directly and non constant
|
|
|
|
// list values are not supported.
|
|
|
|
// TODO: loose this constraint when properly support list type
|
|
|
|
bool isConstantIntListMatching(Value value, SmallVectorImpl<int64_t> &expects) {
|
|
|
|
SmallVector<int64_t> intValues;
|
2022-11-17 04:33:12 +08:00
|
|
|
if (!matchPattern(value, m_TorchListOfConstantInts(intValues)))
|
2022-03-03 00:42:25 +08:00
|
|
|
return false;
|
|
|
|
|
|
|
|
if (intValues.size() != expects.size())
|
|
|
|
return false;
|
|
|
|
|
|
|
|
for (auto it : llvm::zip(intValues, expects)) {
|
|
|
|
if (std::get<0>(it) != std::get<1>(it))
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Emits a runtime `cf.assert` that `lhsDim` and `rhsDim` are equal.
///
/// Both values must be of integer or index type (checked by assertion).
/// Index-typed values are converted with `castIndexToInt64` before the
/// comparison; integer-typed values are compared as they are.
void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
                         Value rhsDim) {
  Type lhsTy = lhsDim.getType();
  Type rhsTy = rhsDim.getType();
  assert((isa<IntegerType>(lhsTy) || isa<IndexType>(lhsTy)) &&
         "must be either integer or index type");
  assert((isa<IntegerType>(rhsTy) || isa<IndexType>(rhsTy)) &&
         "must be either integer or index type");

  Value lhsInt = lhsDim;
  if (lhsTy.isIndex())
    lhsInt = castIndexToInt64(b, loc, lhsDim);
  Value rhsInt = rhsDim;
  if (rhsTy.isIndex())
    rhsInt = castIndexToInt64(b, loc, rhsDim);

  Value dimsAreEqual =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, lhsInt, rhsInt);
  b.create<cf::AssertOp>(loc, dimsAreEqual,
                         b.getStringAttr("mismatching contracting dimension"));
}
|
|
|
|
|
|
|
|
// Creates a tensor with required `sizes` and `elemTy` and fills it with
|
|
|
|
// initElem.
|
|
|
|
Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
|
|
|
Type elemTy, Value initElem) {
|
2022-10-18 12:22:53 +08:00
|
|
|
Value initTensor =
|
|
|
|
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
|
2022-03-03 00:42:25 +08:00
|
|
|
return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
|
|
|
|
}
|
|
|
|
|
2022-04-01 16:23:29 +08:00
|
|
|
/// Creates a tensor with the required `sizes` and element type `elemTy`, and
/// fills it with the zero value of the tensor's element type. Returns the
/// first result of the emitted `linalg.fill`.
Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                           Type elemTy) {
  Value emptyTensor =
      b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
  // The zero constant is typed from the created tensor's element type.
  Type zeroTy = cast<RankedTensorType>(emptyTensor.getType()).getElementType();
  Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(zeroTy));
  auto fillOp = b.create<linalg::FillOp>(loc, zero, emptyTensor);
  return fillOp.getResult(0);
}
|
|
|
|
|
2022-03-03 00:42:25 +08:00
|
|
|
/// Casts the integer-typed value `v` to `index` type via `arith.index_cast`.
///
/// The cast is built with `createOrFold` so that shape computations fold on
/// construction: later uses of `getAsOpFoldResult` then see constants instead
/// of cast ops, which helps keep static shapes in the generated linalg IR.
Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
  assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
  Type indexTy = b.getIndexType();
  return b.createOrFold<arith::IndexCastOp>(loc, indexTy, v);
}
|
|
|
|
|
2022-04-22 01:10:04 +08:00
|
|
|
/// Casts the index-typed value `idx` to `i64` via `arith.index_cast`.
///
/// The cast is built with `createOrFold` so that shape computations fold on
/// construction: later uses of `getAsOpFoldResult` then see constants instead
/// of cast ops, which helps keep static shapes in the generated linalg IR.
Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
  // The precondition is an index type, so the message says so.
  assert(isa<IndexType>(idx.getType()) && "must be called with index type");
  return b.createOrFold<arith::IndexCastOp>(loc, b.getI64Type(), idx);
}
|
|
|
|
|
2022-05-13 20:06:24 +08:00
|
|
|
/// Applies `castIntToIndex` to every element of `intValues`, in order, and
/// returns the converted values.
SmallVector<Value>
castIntVectorToIndexVector(OpBuilder &b, Location loc,
                           SmallVectorImpl<Value> &intValues) {
  return llvm::to_vector<4>(
      llvm::map_range(intValues, [&](Value intValue) -> Value {
        return castIntToIndex(b, loc, intValue);
      }));
}
|
|
|
|
|
2022-11-01 21:08:04 +08:00
|
|
|
/// Applies `castIndexToInt64` to every element of `indexValues`, in order, and
/// returns the converted values.
SmallVector<Value>
castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
                             SmallVectorImpl<Value> &indexValues) {
  return llvm::to_vector<4>(
      llvm::map_range(indexValues, [&](Value indexValue) -> Value {
        return castIndexToInt64(b, loc, indexValue);
      }));
}
|
|
|
|
|
2022-03-03 00:42:25 +08:00
|
|
|
/// Returns the size of dimension `dim` of `v`, built as a `tensor.dim` op via
/// `createOrFold` (so the returned value may be a folded result rather than a
/// new op).
Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
  Value dimSize = b.createOrFold<tensor::DimOp>(loc, v, dim);
  return dimSize;
}
|
|
|
|
|
|
|
|
/// Returns the sizes of dimensions `0..dim` (inclusive) of `tensor`, each
/// obtained through `getDimOp`.
///
/// `tensor` must have a `RankedTensorType`, and `dim` must be smaller than its
/// rank (checked by assertion).
SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
                                          Value tensor, int dim) {
  // Only used by the assertion below, hence maybe_unused for release builds.
  [[maybe_unused]] auto tensorTy = cast<RankedTensorType>(tensor.getType());
  assert(dim < tensorTy.getRank() &&
         "The given dim must be smaller than tensor rank");

  SmallVector<Value> dimSizes;
  for (int idx = 0; idx <= dim; ++idx) {
    Value dimSize = getDimOp(b, loc, tensor, idx);
    dimSizes.push_back(dimSize);
  }
  return dimSizes;
}
|
|
|
|
|
|
|
|
/// Returns the sizes of all dimensions of `tensor`, which must have a
/// `RankedTensorType`. Delegates to `getTensorSizesUntilDim` with the last
/// dimension index (rank - 1).
SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc, Value tensor) {
  auto rankedTy = cast<RankedTensorType>(tensor.getType());
  const int lastDim = rankedTy.getRank() - 1;
  return getTensorSizesUntilDim(b, loc, tensor, lastDim);
}
|
|
|
|
|
|
|
|
/// Returns the product of all dimension sizes of `tensor` as an `i64` value.
///
/// The product is accumulated in `index` type, starting from the constant 1,
/// and converted with `castIndexToInt64` at the end.
Value getTensorSize(OpBuilder &b, Location loc, Value tensor) {
  SmallVector<Value> dimSizes = getTensorSizes(b, loc, tensor);
  Value numElements = b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
  for (Value dimSize : dimSizes) {
    numElements = b.create<arith::MulIOp>(loc, numElements, dimSize);
  }
  return castIndexToInt64(b, loc, numElements);
}
|
|
|
|
|
|
|
|
// Creates a constant of type `elemType` with value `val`.
|
|
|
|
Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType) {
|
2023-04-25 23:52:46 +08:00
|
|
|
TypedAttr attr = {};
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(elemType))
|
2022-03-03 00:42:25 +08:00
|
|
|
attr = b.getFloatAttr(elemType, val);
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::IndexType>(elemType))
|
2022-03-03 00:42:25 +08:00
|
|
|
attr = b.getIndexAttr(val);
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::IntegerType>(elemType))
|
|
|
|
attr = b.getIntegerAttr(elemType,
|
|
|
|
APInt(cast<IntegerType>(elemType).getWidth(), val));
|
2022-03-03 00:42:25 +08:00
|
|
|
if (!attr)
|
|
|
|
return nullptr;
|
|
|
|
return b.create<arith::ConstantOp>(loc, elemType, attr);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Materializes each entry of `ints` as an `arith.constant` of type `i64`,
/// preserving order.
SmallVector<Value> getAsConstantIntValues(OpBuilder &b, Location loc,
                                          SmallVectorImpl<int64_t> &ints) {
  SmallVector<Value> constants;
  for (int64_t val : ints) {
    Value cst = b.create<arith::ConstantOp>(
        loc, b.getIntegerAttr(b.getI64Type(), val));
    constants.push_back(cst);
  }
  return constants;
}
|
|
|
|
|
|
|
|
/// Materializes each entry of `ints` as an `arith.constant` built from an
/// index attribute, preserving order.
SmallVector<Value> getAsConstantIndexValues(OpBuilder &b, Location loc,
                                            SmallVectorImpl<int64_t> &ints) {
  SmallVector<Value> constants;
  for (int64_t val : ints) {
    Value cst = b.create<arith::ConstantOp>(loc, b.getIndexAttr(val));
    constants.push_back(cst);
  }
  return constants;
}
|
|
|
|
|
|
|
|
// This is a temporary solution to deal with types that are not fully supported
|
|
|
|
// like list, dict. For those container tyes, this helper can be used to
|
|
|
|
// convert their elements to valid target type.
|
|
|
|
// TODO: remove this when list gets full support.
|
|
|
|
SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
|
2023-08-16 00:53:28 +08:00
|
|
|
const TypeConverter *converter,
|
2022-03-03 00:42:25 +08:00
|
|
|
SmallVectorImpl<Value> &vs) {
|
|
|
|
return llvm::to_vector<4>(llvm::map_range(vs, [&](Value v) {
|
|
|
|
return converter->materializeTargetConversion(
|
|
|
|
b, loc, converter->convertType(v.getType()), v);
|
|
|
|
}));
|
|
|
|
}
|
|
|
|
|
2022-12-22 03:04:07 +08:00
|
|
|
/// Builds a `RankedTensorType` with the given `elementType` and `encoding`,
/// after passing `shape` through `makeShapeLLVMCompatible`.
mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
                                              mlir::Type elementType,
                                              mlir::Attribute encoding) {
  auto compatibleShape = makeShapeLLVMCompatible(shape);
  return mlir::RankedTensorType::get(compatibleShape, elementType, encoding);
}
|
|
|
|
|
2024-03-01 00:18:46 +08:00
|
|
|
/// Returns the constant held by `scalar` when it is defined by a
/// `Torch::ConstantIntOp`; returns an empty optional otherwise.
static std::optional<int64_t> getIntegerValue(Value scalar) {
  auto constOp = scalar.getDefiningOp<Torch::ConstantIntOp>();
  if (!constOp)
    return std::nullopt;
  return std::optional<int64_t>(constOp.getValue());
}
|
|
|
|
|
2022-03-03 00:42:25 +08:00
|
|
|
// Convert a scalar value to the target type. The scalar value can be an element
|
|
|
|
// from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype
|
|
|
|
// should be converted builtin types.
|
2022-09-20 02:50:51 +08:00
|
|
|
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
2023-09-29 20:19:18 +08:00
|
|
|
std::optional<Type> srcOriginalDtype,
|
2024-03-01 00:18:46 +08:00
|
|
|
std::optional<Type> dstOriginalDtype,
|
|
|
|
std::optional<Value> originalScalar) {
|
2022-03-03 00:42:25 +08:00
|
|
|
Type scalarType = scalar.getType();
|
|
|
|
if (scalarType == dtype)
|
|
|
|
return scalar;
|
|
|
|
|
|
|
|
auto isByteOrChar = [](Type type) {
|
2024-04-11 21:47:35 +08:00
|
|
|
if (auto integerTy = dyn_cast<mlir::IntegerType>(type)) {
|
2022-03-03 00:42:25 +08:00
|
|
|
return integerTy.getWidth() == 8;
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
};
|
|
|
|
|
2024-03-01 00:18:46 +08:00
|
|
|
// We support conversion to Byte dtype only if the original scalar is an
|
|
|
|
// integer constant with value lying between 0 - 63.
|
2022-09-20 02:50:51 +08:00
|
|
|
if (isByteOrChar(dtype)) {
|
2023-09-29 20:19:18 +08:00
|
|
|
if (!dstOriginalDtype.has_value()) {
|
|
|
|
mlir::emitError(loc)
|
|
|
|
<< "unimplemented: for conversion to byte or char type "
|
|
|
|
"dstOriginalDtype has to be passed to convertScalarToDtype";
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
if (dstOriginalDtype->isUnsignedInteger()) {
|
2024-03-01 00:18:46 +08:00
|
|
|
if (originalScalar.has_value()) {
|
|
|
|
std::optional<int64_t> optConstVal =
|
|
|
|
getIntegerValue(originalScalar.value());
|
|
|
|
if (optConstVal.has_value()) {
|
|
|
|
int64_t constVal = optConstVal.value();
|
|
|
|
if (constVal < 0 || constVal > 63) {
|
|
|
|
// Do the conversion only if the original integer value is between
|
|
|
|
// 0 - 63.
|
|
|
|
mlir::emitError(loc)
|
|
|
|
<< "unsupported: conversion to byte type for "
|
|
|
|
"convertScalarToDtype "
|
|
|
|
<< scalarType << "(scalar type) -> " << dtype << "(dtype)";
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2023-09-29 20:19:18 +08:00
|
|
|
}
|
2022-03-03 00:42:25 +08:00
|
|
|
}
|
|
|
|
|
2022-06-14 22:05:22 +08:00
|
|
|
// If the dtype is i1, i.e., a boolean type.
|
|
|
|
if (dtype.isSignlessInteger(1)) {
|
|
|
|
Type scalarType = scalar.getType();
|
|
|
|
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType));
|
2024-04-11 21:47:35 +08:00
|
|
|
if (isa<mlir::FloatType>(scalarType)) {
|
2022-06-14 22:05:22 +08:00
|
|
|
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar,
|
|
|
|
cstZero);
|
2024-04-11 21:47:35 +08:00
|
|
|
} else if (isa<mlir::IntegerType>(scalarType)) {
|
2022-06-14 22:05:22 +08:00
|
|
|
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar,
|
|
|
|
cstZero);
|
|
|
|
} else {
|
|
|
|
mlir::emitError(loc)
|
|
|
|
<< "unsupported scalar type for convertScalarToDtype " << scalarType
|
|
|
|
<< "(scalar type) -> " << dtype << "(dtype)";
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-04-11 21:47:35 +08:00
|
|
|
if (auto dtypeFloat = dyn_cast<mlir::FloatType>(dtype)) {
|
|
|
|
if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType)) {
|
2022-03-03 00:42:25 +08:00
|
|
|
if (scalarFloat.getWidth() > dtypeFloat.getWidth())
|
|
|
|
return b.create<arith::TruncFOp>(loc, dtype, scalar);
|
|
|
|
// Only scalarFloat width < dtypeFloat width can reach here.
|
|
|
|
return b.create<arith::ExtFOp>(loc, dtype, scalar);
|
|
|
|
}
|
2024-04-11 21:47:35 +08:00
|
|
|
assert(isa<mlir::IntegerType>(scalarType));
|
2022-09-20 02:50:51 +08:00
|
|
|
if (scalarType.isSignlessInteger(1) ||
|
|
|
|
(srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
|
2022-03-03 00:42:25 +08:00
|
|
|
return b.create<arith::UIToFPOp>(loc, dtype, scalar);
|
|
|
|
// It's safe to use SIToFPOp because ui8/si8 are the only ones where
|
|
|
|
// unsigned handling is needed, and we checked for that case above.
|
|
|
|
return b.create<arith::SIToFPOp>(loc, dtype, scalar);
|
|
|
|
}
|
|
|
|
|
2024-04-11 21:47:35 +08:00
|
|
|
if (auto dtypeInteger = dyn_cast<mlir::IntegerType>(dtype)) {
|
|
|
|
if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType))
|
2022-03-03 00:42:25 +08:00
|
|
|
return b.create<arith::FPToSIOp>(loc, dtype, scalar);
|
2024-04-11 21:47:35 +08:00
|
|
|
assert(isa<mlir::IntegerType>(scalarType));
|
|
|
|
auto scalarInteger = cast<mlir::IntegerType>(scalarType);
|
2022-03-03 00:42:25 +08:00
|
|
|
if (scalarInteger.getWidth() > dtypeInteger.getWidth())
|
|
|
|
return b.create<arith::TruncIOp>(loc, dtype, scalar);
|
2022-09-20 02:50:51 +08:00
|
|
|
if (scalarType.isSignlessInteger(1) ||
|
|
|
|
(srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
|
2022-03-03 00:42:25 +08:00
|
|
|
return b.create<arith::ExtUIOp>(loc, dtype, scalar);
|
|
|
|
// Only scalarInteger width < dtypeInteger width can reach here.
|
|
|
|
// It's safe to use ExtSIOp here because ui8/si8 are the only ones where
|
|
|
|
// unsigned handling is needed, and we checked for that case above.
|
|
|
|
return b.create<arith::ExtSIOp>(loc, dtype, scalar);
|
|
|
|
}
|
|
|
|
|
2024-05-01 14:36:53 +08:00
|
|
|
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
|
2024-08-07 15:06:48 +08:00
|
|
|
|
|
|
|
// Complex to complex.
|
2024-05-01 14:36:53 +08:00
|
|
|
if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
|
|
|
|
auto dtypeElemType = dtypeComplex.getElementType();
|
|
|
|
|
|
|
|
// Extract the real and imaginary parts of the scalar.
|
|
|
|
// Cast them to the target element type, and create a new complex
|
|
|
|
// value with the target complex type.
|
|
|
|
Value realVal = b.create<complex::ReOp>(loc, scalar);
|
|
|
|
Value imgVal = b.create<complex::ImOp>(loc, scalar);
|
|
|
|
|
|
|
|
realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType);
|
|
|
|
imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType);
|
|
|
|
|
|
|
|
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
|
|
|
|
}
|
2024-08-07 15:06:48 +08:00
|
|
|
|
|
|
|
// Float to complex type.
|
|
|
|
if (auto dtypeFloat = dyn_cast<mlir::FloatType>(scalarType)) {
|
|
|
|
auto complexElementType =
|
|
|
|
cast<mlir::FloatType>(dtypeComplex.getElementType());
|
|
|
|
Value realVal;
|
|
|
|
Value imgVal =
|
|
|
|
b.create<arith::ConstantOp>(loc, b.getZeroAttr(complexElementType));
|
|
|
|
|
|
|
|
if (complexElementType.getWidth() > dtypeFloat.getWidth()) {
|
|
|
|
realVal = b.create<arith::ExtFOp>(loc, complexElementType, scalar);
|
|
|
|
} else if (complexElementType.getWidth() < dtypeFloat.getWidth()) {
|
|
|
|
realVal = b.create<arith::TruncFOp>(loc, complexElementType, scalar);
|
|
|
|
} else {
|
|
|
|
realVal = scalar;
|
|
|
|
}
|
|
|
|
|
|
|
|
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
|
|
|
|
}
|
|
|
|
|
2024-08-14 20:43:00 +08:00
|
|
|
// Int to complex type.
|
|
|
|
if (auto dtypeInt = dyn_cast<mlir::IntegerType>(scalarType)) {
|
|
|
|
auto complexElementType =
|
|
|
|
cast<mlir::FloatType>(dtypeComplex.getElementType());
|
|
|
|
|
|
|
|
Value realVal =
|
|
|
|
b.create<arith::SIToFPOp>(loc, complexElementType, scalar);
|
|
|
|
Value imgVal =
|
|
|
|
b.create<arith::ConstantOp>(loc, b.getZeroAttr(complexElementType));
|
|
|
|
|
|
|
|
return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
|
|
|
|
}
|
|
|
|
|
2024-05-01 14:36:53 +08:00
|
|
|
mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
|
|
|
|
<< scalarType << "(scalar type) -> " << dtype
|
|
|
|
<< "(dtype)";
|
|
|
|
}
|
|
|
|
|
2022-03-03 00:42:25 +08:00
|
|
|
llvm_unreachable("convertScalarToDtype should handle all the types");
|
|
|
|
}
|
|
|
|
|
2023-03-23 04:41:04 +08:00
|
|
|
/// Normalizes a possibly negative dimension into the range [0, dimSize] and
/// returns it as an `index` value.
///
/// If `torchOptionalInt` has `Torch::NoneType`, `defaultValue` is returned
/// unchanged. Otherwise `builtinInt` is wrapped with `toPositiveDimDynamic`
/// against `dimSize` (converted to i64), clamped below at zero and above at
/// `dimSize`, and cast back with `castIntToIndex`.
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
                         Value torchOptionalInt, Value builtinInt,
                         Value defaultValue, Value dimSize) {
  if (isa<Torch::NoneType>(torchOptionalInt.getType()))
    return defaultValue;

  Value sizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
  Value wrappedDim =
      toPositiveDimDynamic(rewriter, loc, builtinInt, sizeAsInt);

  // Lower clamp: wrappedDim < 0 ? 0 : wrappedDim
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(sizeAsInt.getType()));
  Value isNegative = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::slt, wrappedDim, zero);
  Value nonNegativeDim =
      rewriter.create<arith::SelectOp>(loc, isNegative, zero, wrappedDim);

  // Upper clamp: nonNegativeDim > sizeAsInt ? sizeAsInt : nonNegativeDim
  Value exceedsSize = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::sgt, nonNegativeDim, sizeAsInt);
  Value clampedDim = rewriter.create<arith::SelectOp>(
      loc, exceedsSize, sizeAsInt, nonNegativeDim);

  return castIntToIndex(rewriter, loc, clampedDim);
}
|
2022-12-08 13:46:54 +08:00
|
|
|
|
2022-03-03 00:42:25 +08:00
|
|
|
} // namespace Torch
|
|
|
|
} // namespace torch
|
|
|
|
} // namespace mlir
|