//===----------------------------------------------------------------------===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "mlir/IR/BuiltinDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; int64_t Torch::toPositiveDim(int64_t dim, int64_t inputRank) { return dim >= 0 ? dim : dim + inputRank; } bool Torch::isValidDim(int64_t dim, int64_t inputRank) { return dim >= 0 && dim < inputRank; } std::optional Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) { int64_t dim; if (!matchPattern(v, m_TorchConstantInt(&dim))) return std::nullopt; dim = toPositiveDim(dim, length); if (!isValidDim(dim, length)) return std::nullopt; return dim; } bool Torch::getListConstructElements(Value v, SmallVectorImpl &elems) { auto listConstruct = v.getDefiningOp(); if (!listConstruct) return false; elems = llvm::to_vector<4>(listConstruct.getElements()); return true; } torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { if (type.isa()) return torch_upstream::ScalarType::Float; if (type.isa()) return torch_upstream::ScalarType::Double; if (type.isSignedInteger(64)) return torch_upstream::ScalarType::Long; if (type.isSignedInteger(32)) return torch_upstream::ScalarType::Int; if (type.isSignlessInteger(1)) return torch_upstream::ScalarType::Bool; if (type.isBF16()) return torch_upstream::ScalarType::BFloat16; if (type.isF16()) return torch_upstream::ScalarType::Half; if (type.isUnsignedInteger(8)) return torch_upstream::ScalarType::Byte; if (type.isSignedInteger(8)) return torch_upstream::ScalarType::Char; if (type.isa()) { mlir::Type complexElemType = type.cast().getElementType(); if (complexElemType.isF32()) return torch_upstream::ScalarType::ComplexHalf; if (complexElemType.isF64()) return torch_upstream::ScalarType::ComplexFloat; if (complexElemType.isF128()) return torch_upstream::ScalarType::ComplexDouble; } llvm::report_fatal_error("unhandled type for getScalarTypeForType"); } Type Torch::getTypeForTorchType( MLIRContext *context, Type type, mlir::IntegerType::SignednessSemantics signedness) { if (type.isa()) return IntegerType::get(context, 64, signedness); if (type.isa()) return Float64Type::get(context); llvm::report_fatal_error("unhandled type for getTypeForTorchType"); } FailureOr Torch::getTypeForScalarType(MLIRContext *context, torch_upstream::ScalarType dtypeInt, mlir::IntegerType::SignednessSemantics signedness) { switch (dtypeInt) { case torch_upstream::ScalarType::Float: return Float32Type::get(context); case torch_upstream::ScalarType::Double: return Float64Type::get(context); case torch_upstream::ScalarType::Long: return IntegerType::get(context, 64, signedness); case torch_upstream::ScalarType::Int: return IntegerType::get(context, 32, signedness); case torch_upstream::ScalarType::Bool: return IntegerType::get(context, 1); case torch_upstream::ScalarType::BFloat16: return mlir::FloatType::getBF16(context); case torch_upstream::ScalarType::Half: return mlir::FloatType::getF16(context); case torch_upstream::ScalarType::Byte: return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned); case torch_upstream::ScalarType::Char: return mlir::IntegerType::get(context, 8, signedness); case torch_upstream::ScalarType::ComplexHalf: return mlir::ComplexType::get(Float32Type::get(context)); case torch_upstream::ScalarType::ComplexFloat: return mlir::ComplexType::get(Float64Type::get(context)); case torch_upstream::ScalarType::ComplexDouble: return mlir::ComplexType::get(Float128Type::get(context)); case torch_upstream::ScalarType::Undefined: return failure(); default: llvm::report_fatal_error("unhandled type for getTypeForScalarType"); } } FailureOr Torch::getTorchTypeForScalarType(MLIRContext *context, torch_upstream::ScalarType dtypeInt) { switch (dtypeInt) { case torch_upstream::ScalarType::Double: return Torch::FloatType::get(context); case torch_upstream::ScalarType::Long: return Torch::IntType::get(context); case torch_upstream::ScalarType::Undefined: default: return failure(); } } Type Torch::getDefaultDtypeForTorchScalar(Type type) { MLIRContext *context = type.getContext(); if (type.isa()) { // For now, use float32 which is the initial default dtype returned by // `torch.get_default_dtype`. return Float32Type::get(context); } if (type.isa()) return IntegerType::get(context, 64, IntegerType::Signed); if (type.isa()) return IntegerType::get(context, 1); llvm_unreachable( "getDefaultDtypeForTorchScalar called on an unsupported type"); } Type Torch::getBuiltInTypeForTorchScalar(Type type) { MLIRContext *context = type.getContext(); if (type.isa()) return Float64Type::get(context); if (type.isa()) return IntegerType::get(context, 64, IntegerType::Signed); if (type.isa()) return IntegerType::get(context, 1); llvm_unreachable( "getBuiltInTypeForTorchScalar called on an unsupported type"); } Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc, Type dtype) { int intType = (int)getScalarTypeForType(dtype); return rewriter.create(loc, rewriter.getI64IntegerAttr(intType)); } // Helper to convert a tensor to a specific scalar type. Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input, Type dtype) { BaseTensorType origType = input.getType().cast(); Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype); // `convertIntVal` contains the corresponding integer for the dtype which is // used by the aten.to.dtype op. Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype); Value falseVal = rewriter.create(loc, false); Value noneVal = rewriter.create(loc); Value converted = rewriter.create( loc, newType, input, convertIntVal, falseVal, falseVal, noneVal); return converted; } bool Torch::isBuiltInType(Type type) { return isa(type.getDialect()); } std::optional Torch::getTensorRank(Value tensor) { BaseTensorType tensorType = tensor.getType().cast(); if (!tensorType.hasSizes()) return std::nullopt; return tensorType.getSizes().size(); } bool Torch::isViewLikeOp(Operation *op) { // AtenContiguousOp might return a view, so this is conservatively // correct. We could potentially be more precise and identify the cases // that it does not return a view and treat those as having value // semantics. return isa(op); } Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc, float value, Type dtype) { // Creating constants satisfying backend contract. if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(8) || dtype.isInteger(1)) return rewriter.create( loc, rewriter.getI64IntegerAttr((int64_t)value)); if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16()) return rewriter.create(loc, rewriter.getF64FloatAttr(value)); llvm::report_fatal_error( "unhandled type for getConstantWithGivenDtypeAndValue"); } // Return the number of elements of a tensor if the shape is static; otherwise, // return -1. int64_t Torch::getNumberOfElements(RankedTensorType inputType) { if (!inputType.hasStaticShape()) return -1; SmallVector inputShape = makeShapeTorchCompatible(inputType.getShape()); int64_t numel = 1; for (int64_t i = 0; i < inputType.getRank(); i++) numel *= inputShape[i]; return numel; } SmallVector Torch::makeShapeLLVMCompatible(ArrayRef shape) { SmallVector updatedShape(shape); int64_t kDynamic = ShapedType::kDynamic; for (unsigned i = 0; i < shape.size(); i++) { assert(shape[i] >= 0 || shape[i] == kUnknownSize); if (shape[i] == kUnknownSize) updatedShape[i] = kDynamic; } return updatedShape; } SmallVector Torch::makeShapeTorchCompatible(ArrayRef shape) { SmallVector updatedShape(shape); int64_t kDynamic = ShapedType::kDynamic; for (unsigned i = 0; i < shape.size(); i++) { assert(shape[i] >= 0 || shape[i] == kDynamic); if (shape[i] == kDynamic) updatedShape[i] = kUnknownSize; } return updatedShape; } // Helper function to squeeze the input tensor at given dim. // Return the squeezed tensor or failure. FailureOr Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op, Location loc, int64_t dim, Value input) { BaseTensorType inputType = input.getType().cast(); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure(loc, "input tensor must have size"); } SmallVector inputShape{inputType.getSizes()}; unsigned inputRank = inputShape.size(); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) { return rewriter.notifyMatchFailure( op, "dimension to be squeezed is an invalid dim"); } inputShape.erase(inputShape.begin() + dim); Type squeezedType = inputType.getWithSizesAndDtype(inputShape, inputType.getOptionalDtype()); Value cstDim = rewriter.create( loc, rewriter.getI64IntegerAttr(dim)); // Adding a check to verify if the dimension to be squeezed has size 1 or not. Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value dimSize = rewriter.create(loc, input, cstDim); Value cmp = rewriter.create(loc, dimSize, cstOne); rewriter.create( loc, cmp, "squeeze operation possible for dim only when input_shape[dim] == 1."); Value result = rewriter.create(loc, squeezedType, input, cstDim); return result; } // Helper function to unsqueeze the input tensor at given dim. // Return the unsqueezed tensor or failure. FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, Operation *op, Value input, Value dim) { BaseTensorType inputType = input.getType().cast(); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure(op, "input tensor must have size"); } SmallVector unsqueezedShape; ArrayRef inputShape = inputType.getSizes(); // `input` has a reduced rank. Hence add 1. int64_t unsqueezedRank = inputShape.size() + 1; int64_t dimInt = 0; if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { dimInt = toPositiveDim(dimInt, unsqueezedRank); if (!isValidDim(dimInt, unsqueezedRank)) { return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); } unsqueezedShape.append(inputShape.begin(), inputShape.end()); unsqueezedShape.insert(unsqueezedShape.begin() + dimInt, 1); } else { unsqueezedShape.resize(unsqueezedRank, kUnknownSize); } Type unsqueezedType = inputType.getWithSizesAndDtype( unsqueezedShape, inputType.getOptionalDtype()); Value unsqueezed = rewriter.create( op->getLoc(), unsqueezedType, input, dim); return unsqueezed; }