//===----------------------------------------------------------------------===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "mlir/IR/BuiltinDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
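
// Normalize a possibly-negative, Python-style dimension index into a
// non-negative one, e.g. toPositiveDim(-1, /*inputRank=*/4) == 3.
// The result is not range-checked; pair with `isValidDim` for that.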
int64_t Torch::toPositiveDim(int64_t dim, int64_t inputRank) {
  return dim >= 0 ? dim : dim + inputRank;
}
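
// Check whether an already-normalized (non-negative) `dim` is a valid index
// into a tensor of rank `inputRank` (or a list of that length).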
bool Torch::isValidDim(int64_t dim, int64_t inputRank) {
  return dim >= 0 && dim < inputRank;
}
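
// Match `v` against a constant int and normalize it as an index into a list
// of the given `length`. Returns std::nullopt if `v` is not a constant int or
// is out of range; e.g. matching -1 against a list of length 3 yields 2.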
std::optional<int64_t>
Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) {
  int64_t dim;
  if (!matchPattern(v, m_TorchConstantInt(&dim)))
    return std::nullopt;
  dim = toPositiveDim(dim, length);
  if (!isValidDim(dim, length))
    return std::nullopt;
  return dim;
}
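
// If `v` is produced by a `torch.prim.ListConstruct`, copy its element values
// into `elems` and return true; otherwise return false and leave `elems`
// untouched.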
bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
  auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
  if (!listConstruct)
    return false;
  elems = llvm::to_vector<4>(listConstruct.getElements());
  return true;
}
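
// Map a builtin MLIR element type to the corresponding PyTorch ScalarType
// enum value (see torch_upstream::ScalarType), e.g. f32 -> Float,
// si64 -> Long. Aborts on unhandled types.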
torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
  if (isa<Float32Type>(type))
    return torch_upstream::ScalarType::Float;
  if (isa<Float64Type>(type))
    return torch_upstream::ScalarType::Double;
  if (type.isSignedInteger(64))
    return torch_upstream::ScalarType::Long;
  if (type.isSignedInteger(32))
    return torch_upstream::ScalarType::Int;
  if (type.isSignedInteger(16))
    return torch_upstream::ScalarType::Short;
  if (type.isSignlessInteger(1))
    return torch_upstream::ScalarType::Bool;
  if (type.isBF16())
    return torch_upstream::ScalarType::BFloat16;
  if (type.isF16())
    return torch_upstream::ScalarType::Half;
  if (type.isUnsignedInteger(8))
    return torch_upstream::ScalarType::Byte;
  if (type.isSignedInteger(8))
    return torch_upstream::ScalarType::Char;
  if (isa<QUInt8Type>(type))
    return torch_upstream::ScalarType::QUInt8;
  if (isa<QInt8Type>(type))
    return torch_upstream::ScalarType::QInt8;
  if (isa<QInt16Type>(type))
    return torch_upstream::ScalarType::QInt16;
  if (isa<QInt32Type>(type))
    return torch_upstream::ScalarType::QInt32;
  if (isa<ComplexType>(type)) {
    mlir::Type complexElemType = cast<ComplexType>(type).getElementType();
    if (complexElemType.isF16())
      return torch_upstream::ScalarType::ComplexHalf;
    if (complexElemType.isF32())
      return torch_upstream::ScalarType::ComplexFloat;
    if (complexElemType.isF64())
      return torch_upstream::ScalarType::ComplexDouble;
  }
  if (isa<Float8E5M2Type>(type))
    return torch_upstream::ScalarType::Float8_e5m2;
  if (isa<Float8E4M3FNType>(type))
    return torch_upstream::ScalarType::Float8_e4m3fn;
  if (isa<Float8E5M2FNUZType>(type))
    return torch_upstream::ScalarType::Float8_e5m2fnuz;
  if (isa<Float8E4M3FNUZType>(type))
    return torch_upstream::ScalarType::Float8_e4m3fnuz;
  llvm::report_fatal_error("unhandled type for getScalarTypeForType");
}
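
// Map a Torch dialect scalar type to the builtin type used to model it:
// !torch.int becomes a 64-bit integer with the requested signedness, and
// !torch.float becomes f64.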
Type Torch::getTypeForTorchType(
    MLIRContext *context, Type type,
    mlir::IntegerType::SignednessSemantics signedness) {
  if (isa<Torch::IntType>(type))
    return IntegerType::get(context, 64, signedness);
  if (isa<Torch::FloatType>(type))
    return Float64Type::get(context);
  llvm::report_fatal_error("unhandled type for getTypeForTorchType");
}
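
// Inverse of `getScalarTypeForType`: map a PyTorch ScalarType enum value to
// the builtin MLIR element type used to represent it. Returns failure() for
// ScalarType::Undefined and aborts on otherwise-unhandled values.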
FailureOr<Type>
Torch::getTypeForScalarType(MLIRContext *context,
                            torch_upstream::ScalarType dtypeInt) {
  switch (dtypeInt) {
  case torch_upstream::ScalarType::Float:
    return Float32Type::get(context);
  case torch_upstream::ScalarType::Double:
    return Float64Type::get(context);
  case torch_upstream::ScalarType::Long:
    return IntegerType::get(context, 64, mlir::IntegerType::Signed);
  case torch_upstream::ScalarType::Int:
    return IntegerType::get(context, 32, mlir::IntegerType::Signed);
  case torch_upstream::ScalarType::Short:
    return IntegerType::get(context, 16, mlir::IntegerType::Signed);
  case torch_upstream::ScalarType::Bool:
    return IntegerType::get(context, 1);
  case torch_upstream::ScalarType::BFloat16:
    return mlir::FloatType::getBF16(context);
  case torch_upstream::ScalarType::Half:
    return mlir::FloatType::getF16(context);
  case torch_upstream::ScalarType::Byte:
    return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned);
  case torch_upstream::ScalarType::Char:
    return mlir::IntegerType::get(context, 8, mlir::IntegerType::Signed);
  case torch_upstream::ScalarType::QUInt8:
    return QUInt8Type::get(context);
  case torch_upstream::ScalarType::QInt8:
    return QInt8Type::get(context);
  case torch_upstream::ScalarType::QInt16:
    return QInt16Type::get(context);
  case torch_upstream::ScalarType::QInt32:
    return QInt32Type::get(context);
  case torch_upstream::ScalarType::ComplexHalf:
    return mlir::ComplexType::get(Float16Type::get(context));
  case torch_upstream::ScalarType::ComplexFloat:
    return mlir::ComplexType::get(Float32Type::get(context));
  case torch_upstream::ScalarType::ComplexDouble:
    return mlir::ComplexType::get(Float64Type::get(context));
  case torch_upstream::ScalarType::Float8_e5m2:
    return Float8E5M2Type::get(context);
  case torch_upstream::ScalarType::Float8_e4m3fn:
    return Float8E4M3FNType::get(context);
  case torch_upstream::ScalarType::Float8_e5m2fnuz:
    return Float8E5M2FNUZType::get(context);
  case torch_upstream::ScalarType::Float8_e4m3fnuz:
    return Float8E4M3FNUZType::get(context);
  case torch_upstream::ScalarType::Undefined:
    return failure();
  default:
    llvm::report_fatal_error("unhandled type for getTypeForScalarType");
  }
}
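
// Like `getTypeForScalarType`, but returns the Torch dialect scalar type
// (!torch.float / !torch.int) rather than a builtin element type. Only
// Double and Long have Torch scalar equivalents; everything else is
// failure().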
FailureOr<Type>
Torch::getTorchTypeForScalarType(MLIRContext *context,
                                 torch_upstream::ScalarType dtypeInt) {
  switch (dtypeInt) {
  case torch_upstream::ScalarType::Double:
    return Torch::FloatType::get(context);
  case torch_upstream::ScalarType::Long:
    return Torch::IntType::get(context);
  case torch_upstream::ScalarType::Undefined:
  default:
    return failure();
  }
}
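
// Tensor dtype that a Torch scalar receives when it is promoted to a tensor:
// floats get the default dtype (currently f32), ints si64, bools i1.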
Type Torch::getDefaultDtypeForTorchScalar(Type type) {
  MLIRContext *context = type.getContext();
  if (isa<Torch::FloatType>(type)) {
    // For now, use float32 which is the initial default dtype returned by
    // `torch.get_default_dtype`.
    return Float32Type::get(context);
  }
  if (isa<Torch::IntType>(type))
    return IntegerType::get(context, 64, IntegerType::Signed);
  if (isa<Torch::BoolType>(type))
    return IntegerType::get(context, 1);
  llvm_unreachable(
      "getDefaultDtypeForTorchScalar called on an unsupported type");
}
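
// Builtin type that losslessly represents a Torch scalar value: unlike
// `getDefaultDtypeForTorchScalar`, a !torch.float (a Python float) maps to
// f64 here rather than to the f32 default tensor dtype.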
Type Torch::getBuiltInTypeForTorchScalar(Type type) {
  MLIRContext *context = type.getContext();
  if (isa<Torch::FloatType>(type))
    return Float64Type::get(context);
  if (isa<Torch::IntType>(type))
    return IntegerType::get(context, 64, IntegerType::Signed);
  if (isa<Torch::BoolType>(type))
    return IntegerType::get(context, 1);
  llvm_unreachable(
      "getBuiltInTypeForTorchScalar called on an unsupported type");
}
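
// Materialize the integer ScalarType code for `dtype` as a
// torch.constant.int, suitable as the dtype operand of ops such as
// aten.to.dtype.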
Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                                     Type dtype) {
  int intType = (int)getScalarTypeForType(dtype);
  return rewriter.create<ConstantIntOp>(loc,
                                        rewriter.getI64IntegerAttr(intType));
}

// Helper to convert a tensor to a specific scalar type.
Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
                                  Value input, Type dtype) {
  BaseTensorType origType = cast<BaseTensorType>(input.getType());
  Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
  // `convertIntVal` contains the corresponding integer for the dtype which is
  // used by the aten.to.dtype op.
  Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype);
  Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
  Value noneVal = rewriter.create<ConstantNoneOp>(loc);
  Value converted = rewriter.create<AtenToDtypeOp>(
      loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
  return converted;
}

bool Torch::isBuiltInType(Type type) {
  return isa<BuiltinDialect>(type.getDialect());
}
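
// Rank of `tensor` if its type carries shape information, std::nullopt
// otherwise.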
std::optional<unsigned> Torch::getTensorRank(Value tensor) {
  BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType());
  if (!tensorType.hasSizes())
    return std::nullopt;
  return tensorType.getSizes().size();
}
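
// Number of elements of `tensor`: std::nullopt when the type carries no
// shape information, ShapedType::kDynamic when any dim is dynamic, and the
// static product of the sizes otherwise.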
std::optional<int64_t> Torch::getTensorNumel(Value tensor) {
  BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType());
  if (!tensorType.hasSizes())
    return std::nullopt;
  int64_t numel = 1;
  for (auto dim : tensorType.getSizes()) {
    if (dim == ShapedType::kDynamic)
      return ShapedType::kDynamic;
    numel *= dim;
  }
  return numel;
}

bool Torch::isViewLikeOp(Operation *op) {
  // AtenContiguousOp might return a view, so this is conservatively
  // correct. We could potentially be more precise and identify the cases
  // that it does not return a view and treat those as having value
  // semantics.
  return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
             AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
             AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
             AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
             AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
             TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
             AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
             AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
             PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
             AtenPixelShuffleOp, AtenDiagonalOp>(op);
}
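
// Materialize `value` as a torch.constant.int or torch.constant.float
// matching the category of `dtype` (the value is truncated for integer
// dtypes); aborts on unhandled dtypes.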
Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
                                               Location loc, float value,
                                               Type dtype) {
  // Creating constants satisfying backend contract.
  if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(16) ||
      dtype.isInteger(8) || dtype.isInteger(1))
    return rewriter.create<ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr((int64_t)value));
  if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16())
    return rewriter.create<ConstantFloatOp>(loc,
                                            rewriter.getF64FloatAttr(value));
  llvm::report_fatal_error(
      "unhandled type for getConstantWithGivenDtypeAndValue");
}

// Return the number of elements of a tensor if the shape is static; otherwise,
// return -1.
int64_t Torch::getNumberOfElements(RankedTensorType inputType) {
  if (!inputType.hasStaticShape())
    return -1;
  SmallVector<int64_t> inputShape =
      makeShapeTorchCompatible(inputType.getShape());
  int64_t numel = 1;
  for (int64_t i = 0; i < inputType.getRank(); i++)
    numel *= inputShape[i];
  return numel;
}
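
// Translate between the two dynamic-size sentinels in use: Torch dialect
// shapes use the dialect's kUnknownSize, while builtin tensor shapes use
// ShapedType::kDynamic.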
SmallVector<int64_t> Torch::makeShapeLLVMCompatible(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> updatedShape(shape);
  int64_t kDynamic = ShapedType::kDynamic;
  for (unsigned i = 0; i < shape.size(); i++) {
    assert(shape[i] >= 0 || shape[i] == kUnknownSize);
    if (shape[i] == kUnknownSize)
      updatedShape[i] = kDynamic;
  }
  return updatedShape;
}

SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> updatedShape(shape);
  int64_t kDynamic = ShapedType::kDynamic;
  for (unsigned i = 0; i < shape.size(); i++) {
    assert(shape[i] >= 0 || shape[i] == kDynamic);
    if (shape[i] == kDynamic)
      updatedShape[i] = kUnknownSize;
  }
  return updatedShape;
}
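
// Build a ValueTensorType from a list of shape SSA values and a dtype; dims
// that are not constant ints become kUnknownSize.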
ValueTensorType Torch::getTensorTypeFromShapeValues(ArrayRef<Value> shapes,
                                                    Type dtype) {
  assert(!shapes.empty() && "shape vector cannot be empty");
  SmallVector<int64_t> shapeInts;
  for (Value shape : shapes) {
    int64_t dim;
    if (matchPattern(shape, m_TorchConstantInt(&dim)))
      shapeInts.push_back(dim);
    else
      shapeInts.push_back(kUnknownSize);
  }
  return Torch::ValueTensorType::get(shapes[0].getContext(), shapeInts, dtype);
}

// Helper function to get the size of the tensor at the given dimension.
Value Torch::getTensorDimSize(PatternRewriter &rewriter, Value tensor,
                              int64_t dim) {
  auto loc = tensor.getLoc();
  auto dimVal =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
  // Use 'createOrFold' instead of 'create':
  // If the dimension is a constant, then the AtenSizeIntOp is folded to a
  // ConstantIntOp.
  return rewriter.createOrFold<AtenSizeIntOp>(loc, tensor, dimVal);
}

// Helper function to squeeze the input tensor at given dim.
// Return the squeezed tensor or failure.
FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,
                                      Location loc, int64_t dim, Value input) {
  BaseTensorType inputType = cast<BaseTensorType>(input.getType());
  if (!inputType.hasSizes()) {
    return rewriter.notifyMatchFailure(loc, "input tensor must have size");
  }
  SmallVector<int64_t> inputShape{inputType.getSizes()};
  unsigned inputRank = inputShape.size();
  dim = toPositiveDim(dim, inputRank);
  if (!isValidDim(dim, inputRank)) {
    return rewriter.notifyMatchFailure(
        op, "dimension to be squeezed is an invalid dim");
  }
  inputShape.erase(inputShape.begin() + dim);
  Type squeezedType =
      inputType.getWithSizesAndDtype(inputShape, inputType.getOptionalDtype());

  Value cstDim = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(dim));
  // Check that the dimension to be squeezed has size 1.
  Value cstOne = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(1));
  Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, cstDim);
  Value cmp = rewriter.create<Torch::AtenEqIntOp>(loc, dimSize, cstOne);
  rewriter.create<Torch::RuntimeAssertOp>(
      loc, cmp,
      "squeeze operation possible for dim only when input_shape[dim] == 1.");

  Value result =
      rewriter.create<AtenSqueezeDimOp>(loc, squeezedType, input, cstDim);
  return result;
}

// Helper function to unsqueeze the input tensor at given dim.
// Return the unsqueezed tensor or failure.
FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
                                        Operation *op, Value input, Value dim) {
  BaseTensorType inputType = cast<BaseTensorType>(input.getType());
  if (!inputType.hasSizes()) {
    return rewriter.notifyMatchFailure(op, "input tensor must have size");
  }
  FailureOr<Attribute> enc =
      getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim);
  if (failed(enc)) {
    return failure();
  }

  SmallVector<int64_t> unsqueezedShape;
  ArrayRef<int64_t> inputShape = inputType.getSizes();
  // The unsqueezed result has one more dimension than `input`.
  int64_t unsqueezedRank = inputShape.size() + 1;
  int64_t dimInt = 0;
  if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
    dimInt = toPositiveDim(dimInt, unsqueezedRank);
    if (!isValidDim(dimInt, unsqueezedRank)) {
      return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
    }
    unsqueezedShape.append(inputShape.begin(), inputShape.end());
    unsqueezedShape.insert(unsqueezedShape.begin() + dimInt, 1);
  } else {
    unsqueezedShape.resize(unsqueezedRank, kUnknownSize);
  }
  Type unsqueezedType = inputType.getWithSizesAndDtypeAndSparsity(
      unsqueezedShape, inputType.getOptionalDtype(), enc.value());
  Value unsqueezed = rewriter.create<AtenUnsqueezeOp>(
      op->getLoc(), unsqueezedType, input, dim);
  return unsqueezed;
}

// Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If
// yes, then computes the final broadcast shape.
void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,
                                  Value inputA, Value inputB,
                                  SmallVector<int64_t> &resultShape,
                                  SmallVector<Value> &resultShapeValue) {
  SmallVector<int64_t> shapeA{
      cast<BaseTensorType>(inputA.getType()).getSizes()};
  SmallVector<int64_t> shapeB{
      cast<BaseTensorType>(inputB.getType()).getSizes()};
  unsigned rankA = shapeA.size();
  unsigned rankB = shapeB.size();
  unsigned minRank = rankA > rankB ? rankB : rankA;
  // Check whether the shapes of the tensors are broadcastable or not.
  // Two tensors are "broadcastable" if the following rules hold:
  // 1.) Each tensor has at least one dimension.
  // 2.) When iterating over the dimension sizes, starting at the trailing
  // dimension, the dimension sizes must either be equal, one of them is 1, or
  // one of them does not exist.
  for (unsigned i = 0; i < minRank; i++) {
    Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(rankA - i - 1));
    Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(rankB - i - 1));
    Value sizeInputA =
        rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
    Value sizeInputB =
        rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
    Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    Value cmpSizeAEqualsSizeB =
        rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputA, sizeInputB);
    Value cmpSizeAEqualsOne =
        rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputA, torchCstOne);
    Value cmpSizeBEqualsOne =
        rewriter.create<Torch::AtenEqIntOp>(loc, sizeInputB, torchCstOne);
    Value anyBoolOpList = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(cmpSizeAEqualsOne.getType()),
        SmallVector<Value>{cmpSizeAEqualsSizeB, cmpSizeAEqualsOne,
                           cmpSizeBEqualsOne});
    Value cmp = rewriter.create<Torch::AtenAnyBoolOp>(loc, anyBoolOpList);
    rewriter.create<Torch::RuntimeAssertOp>(
        loc, cmp, "tensors are not broadcast compatible");
  }
  // If we reach here then it means both the shapes are broadcast compatible.
  resultShape = rankA >= rankB ? shapeA : shapeB;
  Value shapeTensor = rankA >= rankB ? inputA : inputB;
  for (unsigned i = 0; i < resultShape.size(); i++) {
    Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(i));
    resultShapeValue.push_back(
        rewriter.createOrFold<AtenSizeIntOp>(loc, shapeTensor, sizeDim));
  }

  unsigned resultRank = resultShape.size();
  for (unsigned i = 0; i < minRank; i++) {
    Value sizeDimA = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(rankA - i - 1));
    Value sizeDimB = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(rankB - i - 1));
    Value sizeInputA =
        rewriter.createOrFold<AtenSizeIntOp>(loc, inputA, sizeDimA);
    Value sizeInputB =
        rewriter.createOrFold<AtenSizeIntOp>(loc, inputB, sizeDimB);
    resultShapeValue[resultRank - i - 1] =
        rewriter.create<PrimMaxIntOp>(loc, sizeInputA, sizeInputB);
    if (shapeA[rankA - i - 1] == kUnknownSize ||
        shapeB[rankB - i - 1] == kUnknownSize) {
      resultShape[resultRank - i - 1] = kUnknownSize;
    } else {
      resultShape[resultRank - i - 1] =
          std::max(shapeA[rankA - i - 1], shapeB[rankB - i - 1]);
    }
  }
}
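
// Walk up the parents of `block` and return true if any op carries the
// `torch.assume_strict_symbolic_shapes` attribute, in which case rewrites
// may treat dynamic dims as consistent symbols rather than fully unknown.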
bool Torch::isAssumingStrictSymbolicShapes(Block *block) {
  for (Operation *parentOp = block->getParentOp(); parentOp;
       parentOp = parentOp->getParentOp()) {
    if (parentOp->hasAttr("torch.assume_strict_symbolic_shapes"))
      return true;
  }
  return false;
}
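
// Verify that (opSize, opStride) describe a default (contiguous) layout,
// i.e. stride[i] == product of size[j] for j > i. Constant lists are checked
// statically and produce a match failure when non-default; non-constant
// lists compile to runtime asserts instead.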
LogicalResult Torch::checkDefaultStrideHelper(Operation *op,
                                              PatternRewriter &rewriter,
                                              Value opSize, Value opStride,
                                              Location loc) {

  SmallVector<int64_t> sizeListInts, strideListInts;
  if (matchPattern(opSize, m_TorchListOfConstantInts(sizeListInts)) &&
      matchPattern(opStride, m_TorchListOfConstantInts(strideListInts))) {

    // We only support the cases with default stride values.
    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
    // stride[2] == 1.
    bool isDefaultStride = true;
    for (unsigned i = 0; i < strideListInts.size(); i++) {
      int64_t defaultStride = 1;
      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
        defaultStride *= sizeListInts[j];
      if (defaultStride != strideListInts[i]) {
        isDefaultStride = false;
        break;
      }
    }
    if (!isDefaultStride)
      return rewriter.notifyMatchFailure(
          op, "only default strides supported for empty_strided op");

    return success();

  } else {
    SmallVector<Value> sizeListValues;
    if (!getListConstructElements(opSize, sizeListValues))
      return rewriter.notifyMatchFailure(op, "couldn't get size list values");
    SmallVector<Value> strideListValues;
    if (!getListConstructElements(opStride, strideListValues))
      return rewriter.notifyMatchFailure(op,
                                         "couldn't get stride list values.");
    SmallVector<Value> boolVector;
    for (unsigned i = 0; i < strideListValues.size(); i++) {
      Value defaultStride = rewriter.createOrFold<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(1));
      for (unsigned j = i + 1; j < sizeListValues.size(); j++) {
        defaultStride = rewriter.createOrFold<Torch::AtenMulIntOp>(
            loc, defaultStride, sizeListValues[j]);
      }
      boolVector.push_back(rewriter.createOrFold<Torch::AtenEqIntOp>(
          loc, defaultStride, strideListValues[i]));
    }
    Value allBoolOpList = rewriter.createOrFold<PrimListConstructOp>(
        loc, Torch::ListType::get(rewriter.getType<Torch::BoolType>()),
        boolVector);
    Value cmp = rewriter.createOrFold<Torch::AtenAllBoolOp>(loc, allBoolOpList);
    rewriter.createOrFold<Torch::RuntimeAssertOp>(
        loc, cmp, "not all strides are default");
    return success();
  }
}

// Helper to create a tensor filled with the given scalar. The scalar is
// converted to the element type of the given tensor type.
Value Torch::createInitTensor(PatternRewriter &rewriter, Location loc,
                              BaseTensorType resultType, Value scalar,
                              Value sizeList) {
  assert(resultType.hasDtype() && "result must have dtype");
  Value noneVal = rewriter.create<ConstantNoneOp>(loc);
  Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
  return rewriter.create<AtenFullOp>(loc, resultType, sizeList, scalar, dtype,
                                     /*layout=*/noneVal,
                                     /*device=*/noneVal,
                                     /*memory_format=*/noneVal);
}

// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
// would be converted to the element type of the given `inputType`.
Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc,
                               BaseTensorType inputType, Value scalar) {
  assert(inputType.hasDtype() && "input must have dtype");
  SmallVector<int64_t> sizes;
  BaseTensorType rank0TensorTy = cast<BaseTensorType>(
      inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()));
  Value dimList = rewriter.create<PrimListConstructOp>(
      loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
      ValueRange{});
  return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList);
}
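
// Compute the type obtained by swapping dims `dimA` and `dimB` of `inType`
// (the shape-level effect of aten.transpose.int). Fails if `inType` has no
// sizes.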
LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA,
                                       int64_t dimB, Type &transposedType) {
  if (!inType.hasSizes())
    return failure();
  SmallVector<int64_t> shape(inType.getSizes());
  int64_t tmp = shape[dimA];
  shape[dimA] = shape[dimB];
  shape[dimB] = tmp;
  transposedType = inType.getWithSizesAndDtype(llvm::ArrayRef(shape),
                                               inType.getOptionalDtype());
  return success();
}
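
// Compute the type obtained by permuting the dims of `inType` according to
// `permuteDims`. Fails if `inType` has no sizes or the permutation length
// does not match the rank; the entries of `permuteDims` are assumed to be
// in range.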
LogicalResult Torch::getPermutedType(BaseTensorType inType,
                                     SmallVector<int64_t> permuteDims,
                                     Type &permutedType) {
  if (!inType.hasSizes())
    return failure();

  SmallVector<int64_t> shape(inType.getSizes());
  if (shape.size() != permuteDims.size())
    return failure();

  SmallVector<int64_t> permutedShape;
  for (unsigned i = 0; i < shape.size(); i++)
    permutedShape.push_back(shape[permuteDims[i]]);
  permutedType = inType.getWithSizesAndDtype(llvm::ArrayRef(permutedShape),
                                             inType.getOptionalDtype());
  return success();
}
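
// Accumulator type to use when reducing over `inputType`, following
// PyTorch's accumulation convention (cf. at::toAccumulateType): narrow float
// types accumulate in f32, f64 stays f64, integer types accumulate in i64,
// and anything unhandled falls back to the input type itself.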
Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
  if (inputType.isF16())
    return rewriter.getF32Type();
  if (inputType.isBF16())
    return rewriter.getF32Type();
  if (isa<Float32Type>(inputType))
    return rewriter.getF32Type();
  if (isa<Float64Type>(inputType))
    return rewriter.getF64Type();
  if (inputType.isFloat8E5M2())
    return rewriter.getF32Type();
  if (inputType.isFloat8E4M3FN())
    return rewriter.getF32Type();
  if (inputType.isFloat8E5M2FNUZ())
    return rewriter.getF32Type();
  if (inputType.isFloat8E4M3FNUZ())
    return rewriter.getF32Type();
  if (inputType.isSignedInteger(8))
    return rewriter.getI64Type();
  if (inputType.isUnsignedInteger(8))
    return rewriter.getI64Type();
  if (inputType.isSignedInteger(16))
    return rewriter.getI64Type();
  if (inputType.isSignedInteger(32))
    return rewriter.getI64Type();
  if (inputType.isSignedInteger(64))
    return rewriter.getI64Type();
  return inputType;
}