//===----------------------------------------------------------------------===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "mlir/IR/BuiltinDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; int64_t Torch::toPositiveDim(int64_t dim, int64_t inputRank) { return dim >= 0 ? dim : dim + inputRank; } bool Torch::isValidDim(int64_t dim, int64_t inputRank) { return dim >= 0 && dim < inputRank; } llvm::Optional Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) { int64_t dim; if (!matchPattern(v, m_TorchConstantInt(&dim))) return llvm::None; dim = toPositiveDim(dim, length); if (!isValidDim(dim, length)) return llvm::None; return dim; } bool Torch::getListConstructElements(Value v, SmallVectorImpl &elems) { auto listConstruct = v.getDefiningOp(); if (!listConstruct) return false; elems = llvm::to_vector<4>(listConstruct.elements()); return true; } torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { if (type.isa()) return torch_upstream::ScalarType::Float; if (type.isa()) return torch_upstream::ScalarType::Double; if (type.isSignedInteger(64)) return torch_upstream::ScalarType::Long; if (type.isSignedInteger(32)) return torch_upstream::ScalarType::Int; if (type.isSignlessInteger(1)) return torch_upstream::ScalarType::Bool; if (type.isBF16()) return torch_upstream::ScalarType::BFloat16; llvm::report_fatal_error("unhandled type for getScalarTypeForType"); } Type Torch::getTypeForScalarType( MLIRContext *context, torch_upstream::ScalarType dtypeInt, mlir::IntegerType::SignednessSemantics signedness) { switch (dtypeInt) { case torch_upstream::ScalarType::Float: return Float32Type::get(context); case torch_upstream::ScalarType::Double: return Float64Type::get(context); case torch_upstream::ScalarType::Long: return IntegerType::get(context, 64, signedness); case torch_upstream::ScalarType::Int: return IntegerType::get(context, 32, signedness); case torch_upstream::ScalarType::Bool: return IntegerType::get(context, 1); case torch_upstream::ScalarType::BFloat16: return mlir::FloatType::getBF16(context); default: return Type(); } } Type Torch::getTorchTypeForScalarType(MLIRContext *context, torch_upstream::ScalarType dtypeInt) { switch (dtypeInt) { case torch_upstream::ScalarType::Double: return Torch::FloatType::get(context); case torch_upstream::ScalarType::Long: return Torch::IntType::get(context); default: llvm::report_fatal_error( "Unsupported scalar type to Torch type conversion"); } } Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc, Type dtype) { int intType = (int)getScalarTypeForType(dtype); return rewriter.create(loc, rewriter.getI64IntegerAttr(intType)); } // Helper to convert a tensor to a specific scalar type. Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input, Type dtype) { BaseTensorType origType = input.getType().cast(); Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype); // `convertIntVal` contains the corresponding integer for the dtype which is // used by the aten.to.dtype op. Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype); Value falseVal = rewriter.create(loc, false); Value noneVal = rewriter.create(loc); Value converted = rewriter.create( loc, newType, input, convertIntVal, falseVal, falseVal, noneVal); return converted; } bool Torch::isBuiltInType(Type type) { return isa(type.getDialect()); } int Torch::getTensorRank(Value tensor) { int tensorRank = -1; BaseTensorType tensorType = tensor.getType().cast(); if (tensorType.hasSizes()) { ArrayRef tensorShape = tensorType.getSizes(); tensorRank = tensorShape.size(); } return tensorRank; }