//===----------------------------------------------------------------------===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
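
// Convert a possibly-negative dimension index into its non-negative
// equivalent, following PyTorch's convention that negative dimensions count
// from the end; e.g. toPositiveDim(-1, /*inputRank=*/4) returns 3. The result
// is not range-checked, so callers should pair this with isValidDim below.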
int64_t Torch::toPositiveDim(int64_t dim, int64_t inputRank) {
  return dim >= 0 ? dim : dim + inputRank;
}

bool Torch::isValidDim(int64_t dim, int64_t inputRank) {
  return dim >= 0 && dim < inputRank;
}
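
// If `v` is a constant integer that is a legal (possibly negative) index into
// a list of the given `length`, return the equivalent non-negative index;
// otherwise return llvm::None.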
llvm::Optional<int64_t>
Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) {
  int64_t dim;
  if (!matchPattern(v, m_TorchConstantInt(&dim)))
    return llvm::None;
  dim = toPositiveDim(dim, length);
  if (!isValidDim(dim, length))
    return llvm::None;
  return dim;
}
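
// If `v` is produced by a `torch.prim.ListConstruct` op, copy its operands
// into `elems` and return true; otherwise leave `elems` untouched and return
// false.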
bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
  auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
  if (!listConstruct)
    return false;
  elems = llvm::to_vector<4>(listConstruct.elements());
  return true;
}
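
// Map a builtin MLIR element type to the corresponding PyTorch ScalarType
// enum value (mirrored in torch_upstream); aborts on unhandled types.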
torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
  if (type.isa<Float32Type>())
    return torch_upstream::ScalarType::Float;
  if (type.isa<Float64Type>())
    return torch_upstream::ScalarType::Double;
  if (type.isSignedInteger(64))
    return torch_upstream::ScalarType::Long;
  if (type.isSignedInteger(32))
    return torch_upstream::ScalarType::Int;
  // Bool tensors carry a signless i1 element type (see getTypeForScalarType
  // below), so test for signless rather than unsigned i1.
  if (type.isSignlessInteger(1))
    return torch_upstream::ScalarType::Bool;
  llvm::report_fatal_error("unhandled type for getScalarTypeForType");
}
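
// Inverse of getScalarTypeForType: map a PyTorch ScalarType back to a builtin
// MLIR type, using `signedness` for the sized integer cases. Returns a null
// Type for scalar types with no handled MLIR equivalent.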
Type Torch::getTypeForScalarType(
    MLIRContext *context, torch_upstream::ScalarType dtypeInt,
    mlir::IntegerType::SignednessSemantics signedness) {
  switch (dtypeInt) {
  case torch_upstream::ScalarType::Float:
    return Float32Type::get(context);
  case torch_upstream::ScalarType::Double:
    return Float64Type::get(context);
  case torch_upstream::ScalarType::Long:
    return IntegerType::get(context, 64, signedness);
  case torch_upstream::ScalarType::Int:
    return IntegerType::get(context, 32, signedness);
  case torch_upstream::ScalarType::Bool:
    // Bool is always signless i1, regardless of `signedness`.
    return IntegerType::get(context, 1);
  default:
    return Type();
  }
}
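
// Materialize the integer encoding of `dtype` (its ScalarType enum value) as
// a `torch.constant.int`, suitable for ops such as `aten.to.dtype` that take
// the dtype as an integer operand.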
Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                                     Type dtype) {
  int intType = (int)getScalarTypeForType(dtype);
  return rewriter.create<ConstantIntOp>(loc,
                                        rewriter.getI64IntegerAttr(intType));
}
// Helper to convert a tensor to a specific scalar type.
Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
                                  Value input, Type dtype) {
  BaseTensorType origType = input.getType().cast<BaseTensorType>();
  Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
  // `convertIntVal` is the integer encoding of `dtype` expected by the
  // `aten.to.dtype` op.
  Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype);
  // The two `falseVal` operands are the `non_blocking` and `copy` arguments
  // of `aten.to.dtype`; `noneVal` is its `memory_format` argument.
  Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
  Value noneVal = rewriter.create<ConstantNoneOp>(loc);
  Value converted = rewriter.create<AtenToDtypeOp>(
      loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
  return converted;
}
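
// A minimal usage sketch (hypothetical caller, not part of this file):
// inside a rewrite pattern, promote a tensor to f32 before building a
// floating-point op:
//   Value floatInput =
//       convertTensorToDtype(rewriter, loc, input, rewriter.getF32Type());
// The resulting value has the same sizes as `input` with an f32 dtype.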