diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h
new file mode 100644
index 000000000..e492e0efc
--- /dev/null
+++ b/include/torch-mlir/Conversion/Utils/Utils.h
@@ -0,0 +1,84 @@
+//===------------------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIR_CONVERSION_UTILS_H
+#define TORCHMLIR_CONVERSION_UTILS_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace torch {
+namespace Torch {
+
+LogicalResult verifyLinalgCompatibleTypes(Operation *op,
+                                          PatternRewriter &rewriter);
+
+LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v);
+
+Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
+                           Value inputRank);
+
+void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank);
+
+bool isConstantIntListMatching(Value value, SmallVectorImpl<int64_t> &expects);
+
+void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
+                         Value rhsDim);
+
+// Creates a tensor with required `sizes` and `elemTy` and fills it with
+// `initElem`.
+Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
+                       Type elemTy, Value initElem);
+
+Value castIntToIndex(OpBuilder &b, Location loc, Value v);
+
+Value castIndexToInt(OpBuilder &b, Location loc, Value idx);
+
+Value getDimOp(OpBuilder &b, Location loc, Value v, int dim);
+
+SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
+                                          Value tensor, int dim);
+
+SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc, Value tensor);
+
+Value getTensorSize(OpBuilder &b, Location loc, Value tensor);
+
+Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
+                           Type elemTy);
+
+// Creates a constant of type `elemType` with value `val`.
+Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType);
+
+SmallVector<Value> getAsConstantIntValues(OpBuilder &b, Location loc,
+                                          SmallVectorImpl<int64_t> &ints);
+
+SmallVector<Value> getAsConstantIndexValues(OpBuilder &b, Location loc,
+                                            SmallVectorImpl<int64_t> &ints);
+
+// This is a temporary solution to deal with types that are not fully
+// supported, like list and dict. For those container types, this helper can
+// be used to convert their elements to the valid target type.
+// TODO: remove this when list gets full support.
+SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
+                                          TypeConverter *converter,
+                                          SmallVectorImpl<Value> &vs);
+
+// Converts a scalar value to the target type. The scalar value can be an
+// element from a tensor or a scalar in the PyTorch dialect. Both the scalar
+// and `dtype` should already be converted builtin types.
+Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
+                           Type dtype);
+
+} // namespace Torch
+} // namespace torch
+} // namespace mlir
+
+#endif // TORCHMLIR_CONVERSION_UTILS_H
diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt
index 9f9c6c653..dad3f6be7 100644
--- a/lib/Conversion/CMakeLists.txt
+++ b/lib/Conversion/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(TorchToLinalg)
 add_subdirectory(TorchToSCF)
 add_subdirectory(TorchToStd)
 add_subdirectory(TorchToTosa)
+add_subdirectory(Utils)
 
 # TODO: Automate this with add_torch_mlir_conversion_library.
 #get_property(torch_mlir_conversion_libs GLOBAL PROPERTY TORCH_MLIR_CONVERSION_LIBS)
@@ -20,5 +21,6 @@ add_mlir_library(TorchMLIRConversionPasses
   TorchMLIRTorchToSCF
   TorchMLIRTorchToStd
   TorchMLIRTorchToTosa
+  TorchMLIRConversionUtils
   #${torch_mlir_conversion_libs}
 )
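For context: with the new `TorchMLIRConversionUtils` target registered above, any other conversion library can pick up these helpers through its `LINK_LIBS`. A minimal sketch of a hypothetical consumer, not part of this diff (the `TorchMLIRTorchToFoo` name and `TorchToFoo.cpp` source are illustrative only; the dependency on `TorchMLIRConversionUtils` is the real point):

```cmake
# Hypothetical consumer library; only the TorchMLIRConversionUtils
# dependency comes from this patch, the rest is a sketch.
add_mlir_conversion_library(TorchMLIRTorchToFoo
  TorchToFoo.cpp

  LINK_LIBS PUBLIC
  MLIRIR
  TorchMLIRConversionUtils
)
```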
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 6b929979e..8ade5762b 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -25,6 +25,7 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
@@ -57,117 +58,6 @@ using namespace mlir::torch::torch_upstream; // For ScalarType and type
 // that these patterns become mostly mechanical associations of
 // "aten.foo -> linalg.foo".
 
-static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
-                                                 PatternRewriter &rewriter) {
-  // Check the value tensor is ranked as expected by Linalg.
-  // TODO: Remove this check but use a separate verification pass to verify the
-  // invariants expected by later passes.
-  auto isValidLinalgType = [](Type type) {
-    auto tensor = type.dyn_cast<ValueTensorType>();
-    return !tensor ||
-           tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
-  };
-
-  bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
-               llvm::all_of(op->getResultTypes(), isValidLinalgType);
-  if (!valid)
-    return rewriter.notifyMatchFailure(op, "type cannot be lowered to linalg");
-  return success();
-}
-
-static LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op,
-                                  Value v) {
-  Type type = v.getType();
-  if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
-      type.isa<mlir::NoneType>())
-    return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
-  return success();
-}
-
-// Generate IR: dim = dim >= 0 ? dim : dim + inputRank
-static Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
-                                  Value inputRank) {
-  assert(dim.getType().isa<IntegerType>() &&
-         "dim arg of toPositiveDim must be integer type");
-  Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
-  Value cst0 =
-      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
-  Value predDimGEZero =
-      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
-  Value dimInt =
-      b.create<arith::SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
-  return dimInt;
-}
-
-// Generate IR: assert(dim >= 0 && dim < inputRank)
-static void assertIsValidDim(OpBuilder &b, Location loc, Value dim,
-                             Value inputRank) {
-  assert(dim.getType().isa<IntegerType>() &&
-         "dim arg of assertIsValidDim must be integer type");
-  Value cst0 =
-      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
-  Value predGEZero =
-      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
-  b.create<cf::AssertOp>(
-      loc, predGEZero, b.getStringAttr("dim must be greater or equal to zero"));
-  Value predLTInputRank =
-      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, dim, inputRank);
-  b.create<cf::AssertOp>(loc, predLTInputRank,
-                         b.getStringAttr("dim must be smaller than inputRank"));
-}
-
-// Hack to deal with Torch list-type arguments, which are not supported end
-// to end. Constant values can be extracted directly and non-constant list
-// values are not supported.
-// TODO: relax this constraint when list types are properly supported.
-static bool isConstantIntListMatching(Value value,
-                                      SmallVectorImpl<int64_t> &expects) {
-  SmallVector<int64_t> intValues;
-  if (!matchPattern(value, m_TorchConstantIntList(intValues)))
-    return false;
-
-  if (intValues.size() != expects.size())
-    return false;
-
-  for (auto it : llvm::zip(intValues, expects)) {
-    if (std::get<0>(it) != std::get<1>(it))
-      return false;
-  }
-  return true;
-}
-
-static Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
-  assert(v.getType().isa<IntegerType>() && "must be called with integer type");
-  return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
-}
-
-static Value castIndexToInt(OpBuilder &b, Location loc, Value idx) {
-  assert(idx.getType().isa<IndexType>() && "must be called with index type");
-  return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
-}
-
-static Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
-  return b.createOrFold<tensor::DimOp>(loc, v, dim);
-}
-
-static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
-                                Value rhsDim) {
-  Type lhsType = lhsDim.getType();
-  Type rhsType = rhsDim.getType();
-  auto checkIntOrIndex = [](Type type) {
-    assert((type.isa<IntegerType>() || type.isa<IndexType>()) &&
-           "must be either integer or index type");
-  };
-  checkIntOrIndex(lhsType);
-  checkIntOrIndex(rhsType);
-  Value lhsDimInt = lhsType.isIndex() ? castIndexToInt(b, loc, lhsDim) : lhsDim;
-  Value rhsDimInt = rhsType.isIndex() ? castIndexToInt(b, loc, rhsDim) : rhsDim;
-  Value contractingDimEqual = b.create<arith::CmpIOp>(
-      loc, arith::CmpIPredicate::eq, lhsDimInt, rhsDimInt);
-  b.create<cf::AssertOp>(loc, contractingDimEqual,
-                         b.getStringAttr("mismatching contracting dimension"));
-}
-
 template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred,
           arith::CmpIPredicate ispred>
 static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
@@ -199,64 +89,6 @@ static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
       b, loc, elementalType, lhs, rhs);
 }
 
-static SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
-                                                 Value tensor, int dim) {
-  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
-  assert(dim < type.getRank() &&
-         "The given dim must be smaller than tensor rank");
-  (void)type;
-  SmallVector<Value> sizes;
-  for (int i = 0; i <= dim; i++)
-    sizes.push_back(getDimOp(b, loc, tensor, i));
-  return sizes;
-}
-
-static SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc,
-                                         Value tensor) {
-  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
-  return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
-}
-
-static Value getTensorSize(OpBuilder &b, Location loc, Value tensor) {
-  SmallVector<Value> sizes(getTensorSizes(b, loc, tensor));
-  Value productResult = b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
-  for (Value size : sizes)
-    productResult = b.create<arith::MulIOp>(loc, productResult, size);
-  return castIndexToInt(b, loc, productResult);
-}
-
-static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
-                                  Type elemTy) {
-  Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
-  RankedTensorType type = initTensor.getType().cast<RankedTensorType>();
-  Value c0 =
-      b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
-  return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
-}
-
-// Creates a tensor with required `sizes` and `elemTy` and fills it with
-// `initElem`.
-static Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
-                              Type elemTy, Value initElem) {
-  Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
-  return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
-}
-// Creates a constant of type `elemType` with value `val`.
-static Value getConstant(OpBuilder &b, Location loc, int64_t val,
-                         Type elemType) {
-  Attribute attr = {};
-  if (elemType.isa<mlir::FloatType>())
-    attr = b.getFloatAttr(elemType, val);
-  if (elemType.isa<mlir::IndexType>())
-    attr = b.getIndexAttr(val);
-  if (elemType.isa<mlir::IntegerType>())
-    attr = b.getIntegerAttr(
-        elemType, APInt(elemType.cast<mlir::IntegerType>().getWidth(), val));
-  if (!attr)
-    return nullptr;
-  return b.create<arith::ConstantOp>(loc, elemType, attr);
-}
-
 // Helper function to calculate the output tensor dims for convolution-like
 // ops. Along each dim:
 // dim_out =
@@ -285,42 +117,12 @@ static Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
   return castIntToIndex(b, loc, out);
 }
 
-static SmallVector<Value>
-getAsConstantIntValues(OpBuilder &b, Location loc,
-                       SmallVectorImpl<int64_t> &ints) {
-  return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
-    return b.create<arith::ConstantOp>(loc,
-                                       b.getIntegerAttr(b.getI64Type(), val));
-  }));
-}
-
-static SmallVector<Value>
-getAsConstantIndexValues(OpBuilder &b, Location loc,
-                         SmallVectorImpl<int64_t> &ints) {
-  return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
-    return b.create<arith::ConstantOp>(loc, b.getIndexAttr(val));
-  }));
-}
-
 static SmallVector<OpFoldResult>
 getAsOpFoldResult(OpBuilder &b, Location loc, SmallVectorImpl<int64_t> &ints) {
   return llvm::to_vector<4>(llvm::map_range(
       ints, [&](int64_t val) -> OpFoldResult { return b.getIndexAttr(val); }));
 }
 
-// This is a temporary solution to deal with types that are not fully
-// supported, like list and dict. For those container types, this helper can
-// be used to convert their elements to the valid target type.
-// TODO: remove this when list gets full support.
-static SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
-                                                 TypeConverter *converter,
-                                                 SmallVectorImpl<Value> &vs) {
-  return llvm::to_vector<4>(llvm::map_range(vs, [&](Value v) {
-    return converter->materializeTargetConversion(
-        b, loc, converter->convertType(v.getType()), v);
-  }));
-}
-
 // Helper function to get the padding tensor given the padding int values.
 static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
                              SmallVectorImpl<int64_t> &lowPaddingInts,
@@ -377,66 +179,6 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) {
   return buildNormalCdf(b, loc, x, zero, one);
 }
 
-// Converts a scalar value to the target type. The scalar value can be an
-// element from a tensor or a scalar in the PyTorch dialect. Both the scalar
-// and `dtype` should already be converted builtin types.
-static Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
-                                  Type dtype) {
-  Type scalarType = scalar.getType();
-  if (scalarType == dtype)
-    return scalar;
-
-  // TODO: For the byte (ui8) or char (i8) case, we need the unconverted dtype
-  // to be able to know if we need signed or unsigned conversion.
-  auto isByteOrChar = [](Type type) {
-    if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
-      return integerTy.getWidth() == 8;
-    }
-    return false;
-  };
-
-  if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
-      dtype.isSignlessInteger(1)) {
-    // TODO: Handle to-boolean conversion (from-boolean conversion is handled).
-    mlir::emitError(loc)
-        << "unsupported byte, char or bool type for convertScalarToDtype "
-        << scalarType << "(scalar type) -> " << dtype << "(dtype)";
-    return nullptr;
-  }
-
-  if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
-    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
-      if (scalarFloat.getWidth() > dtypeFloat.getWidth())
-        return b.create<arith::TruncFOp>(loc, dtype, scalar);
-      // Only scalarFloat width < dtypeFloat width can reach here.
-      return b.create<arith::ExtFOp>(loc, dtype, scalar);
-    }
-    assert(scalarType.isa<mlir::IntegerType>());
-    if (scalarType.isSignlessInteger(1))
-      return b.create<arith::UIToFPOp>(loc, dtype, scalar);
-    // It's safe to use SIToFPOp because ui8/si8 are the only ones where
-    // unsigned handling is needed, and we checked for that case above.
-    return b.create<arith::SIToFPOp>(loc, dtype, scalar);
-  }
-
-  if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
-    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
-      return b.create<arith::FPToSIOp>(loc, dtype, scalar);
-    assert(scalarType.isa<mlir::IntegerType>());
-    auto scalarInteger = scalarType.cast<mlir::IntegerType>();
-    if (scalarInteger.getWidth() > dtypeInteger.getWidth())
-      return b.create<arith::TruncIOp>(loc, dtype, scalar);
-    if (scalarType.isSignlessInteger(1))
-      return b.create<arith::ExtUIOp>(loc, dtype, scalar);
-    // Only scalarInteger width < dtypeInteger width can reach here.
-    // It's safe to use ExtSIOp here because ui8/si8 are the only ones where
-    // unsigned handling is needed, and we checked for that case above.
-    return b.create<arith::ExtSIOp>(loc, dtype, scalar);
-  }
-
-  llvm_unreachable("convertScalarToDtype should handle all the types");
-}
-
 // Create a reduction of `tensorOperand`, reducing along the dimensions
 // in `dimSet`. If `keepDim` is true, the output tensor is the same
 // rank as the `tensorOperand` and reduced dimensions are set to size 1.
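For context, a minimal sketch of how a conversion pattern is expected to consume the relocated helpers once they live in the shared header. The `lowerWithZeroInit` function and the op it replaces are hypothetical; only the `Torch::` helpers it calls are declared in this patch:

```cpp
// Hypothetical usage sketch, not part of this diff.
#include "torch-mlir/Conversion/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;

// Replaces `op` with a zero-filled tensor matching `input`'s shape,
// illustrating the shared size-query and init-tensor helpers.
static LogicalResult lowerWithZeroInit(Operation *op, Value input,
                                       PatternRewriter &rewriter) {
  if (failed(Torch::verifyLinalgCompatibleTypes(op, rewriter)))
    return failure();
  Location loc = op->getLoc();
  // Dynamic dimension sizes of `input`, one `tensor.dim` per dimension.
  SmallVector<Value> sizes = Torch::getTensorSizes(rewriter, loc, input);
  Type elemTy = input.getType().cast<RankedTensorType>().getElementType();
  // linalg.init_tensor of that shape, filled with the zero of `elemTy`.
  Value zeroes = Torch::createZeroInitTensor(rewriter, loc, sizes, elemTy);
  rewriter.replaceOp(op, zeroes);
  return success();
}
```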
diff --git a/lib/Conversion/Utils/CMakeLists.txt b/lib/Conversion/Utils/CMakeLists.txt
new file mode 100644
index 000000000..f4baf8634
--- /dev/null
+++ b/lib/Conversion/Utils/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_conversion_library(TorchMLIRConversionUtils
+  Utils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/Utils
+)
diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
new file mode 100644
index 000000000..a50e210a9
--- /dev/null
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -0,0 +1,277 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/Utils/Utils.h"
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+
+namespace mlir {
+namespace torch {
+namespace Torch {
+
+LogicalResult verifyLinalgCompatibleTypes(Operation *op,
+                                          PatternRewriter &rewriter) {
+  // Check the value tensor is ranked as expected by Linalg.
+  // TODO: Remove this check but use a separate verification pass to verify the
+  // invariants expected by later passes.
+  auto isValidLinalgType = [](Type type) {
+    auto tensor = type.dyn_cast<ValueTensorType>();
+    return !tensor ||
+           tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
+  };
+
+  bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
+               llvm::all_of(op->getResultTypes(), isValidLinalgType);
+  if (!valid)
+    return rewriter.notifyMatchFailure(op, "type cannot be lowered to linalg");
+  return success();
+}
+
+LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
+  Type type = v.getType();
+  if (type.isa<OptionalType>() || type.isa<Torch::NoneType>() ||
+      type.isa<mlir::NoneType>())
+    return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
+  return success();
+}
+
+// Generate IR: dim = dim >= 0 ? dim : dim + inputRank
+Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
+                           Value inputRank) {
+  assert(dim.getType().isa<IntegerType>() &&
+         "dim arg of toPositiveDim must be integer type");
+  Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
+  Value cst0 =
+      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
+  Value predDimGEZero =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
+  Value dimInt =
+      b.create<arith::SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
+  return dimInt;
+}
+
+// Generate IR: assert(dim >= 0 && dim < inputRank)
+void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank) {
+  assert(dim.getType().isa<IntegerType>() &&
+         "dim arg of assertIsValidDim must be integer type");
+  Value cst0 =
+      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
+  Value predGEZero =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
+  b.create<cf::AssertOp>(
+      loc, predGEZero, b.getStringAttr("dim must be greater or equal to zero"));
+  Value predLTInputRank =
+      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, dim, inputRank);
+  b.create<cf::AssertOp>(loc, predLTInputRank,
+                         b.getStringAttr("dim must be smaller than inputRank"));
+}
+
+// Hack to deal with Torch list-type arguments, which are not supported end
+// to end. Constant values can be extracted directly and non-constant list
+// values are not supported.
+// TODO: relax this constraint when list types are properly supported.
+bool isConstantIntListMatching(Value value, SmallVectorImpl<int64_t> &expects) {
+  SmallVector<int64_t> intValues;
+  if (!matchPattern(value, m_TorchConstantIntList(intValues)))
+    return false;
+
+  if (intValues.size() != expects.size())
+    return false;
+
+  for (auto it : llvm::zip(intValues, expects)) {
+    if (std::get<0>(it) != std::get<1>(it))
+      return false;
+  }
+  return true;
+}
+
+void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
+                         Value rhsDim) {
+  Type lhsType = lhsDim.getType();
+  Type rhsType = rhsDim.getType();
+  auto checkIntOrIndex = [](Type type) {
+    assert((type.isa<IntegerType>() || type.isa<IndexType>()) &&
+           "must be either integer or index type");
+  };
+  checkIntOrIndex(lhsType);
+  checkIntOrIndex(rhsType);
+  Value lhsDimInt = lhsType.isIndex() ? castIndexToInt(b, loc, lhsDim) : lhsDim;
+  Value rhsDimInt = rhsType.isIndex() ? castIndexToInt(b, loc, rhsDim) : rhsDim;
+  Value contractingDimEqual = b.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, lhsDimInt, rhsDimInt);
+  b.create<cf::AssertOp>(loc, contractingDimEqual,
+                         b.getStringAttr("mismatching contracting dimension"));
+}
+
+// Creates a tensor with required `sizes` and `elemTy` and fills it with
+// `initElem`.
+Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
+                       Type elemTy, Value initElem) {
+  Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
+  return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
+}
+
+Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
+  assert(v.getType().isa<IntegerType>() && "must be called with integer type");
+  return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
+}
+
+Value castIndexToInt(OpBuilder &b, Location loc, Value idx) {
+  assert(idx.getType().isa<IndexType>() && "must be called with index type");
+  return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
+}
+
+Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
+  return b.createOrFold<tensor::DimOp>(loc, v, dim);
+}
+
+SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
+                                          Value tensor, int dim) {
+  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
+  assert(dim < type.getRank() &&
+         "The given dim must be smaller than tensor rank");
+  (void)type;
+  SmallVector<Value> sizes;
+  for (int i = 0; i <= dim; i++)
+    sizes.push_back(getDimOp(b, loc, tensor, i));
+  return sizes;
+}
+
+SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc, Value tensor) {
+  RankedTensorType type = tensor.getType().cast<RankedTensorType>();
+  return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
+}
+
+Value getTensorSize(OpBuilder &b, Location loc, Value tensor) {
+  SmallVector<Value> sizes(getTensorSizes(b, loc, tensor));
+  Value productResult = b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
+  for (Value size : sizes)
+    productResult = b.create<arith::MulIOp>(loc, productResult, size);
+  return castIndexToInt(b, loc, productResult);
+}
+
+Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
+                           Type elemTy) {
+  Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
+  RankedTensorType type = initTensor.getType().cast<RankedTensorType>();
+  Value c0 =
+      b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
+  return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
+}
+
+// Creates a constant of type `elemType` with value `val`.
+Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType) {
+  Attribute attr = {};
+  if (elemType.isa<mlir::FloatType>())
+    attr = b.getFloatAttr(elemType, val);
+  if (elemType.isa<mlir::IndexType>())
+    attr = b.getIndexAttr(val);
+  if (elemType.isa<mlir::IntegerType>())
+    attr = b.getIntegerAttr(
+        elemType, APInt(elemType.cast<mlir::IntegerType>().getWidth(), val));
+  if (!attr)
+    return nullptr;
+  return b.create<arith::ConstantOp>(loc, elemType, attr);
+}
+
+SmallVector<Value> getAsConstantIntValues(OpBuilder &b, Location loc,
+                                          SmallVectorImpl<int64_t> &ints) {
+  return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
+    return b.create<arith::ConstantOp>(loc,
+                                       b.getIntegerAttr(b.getI64Type(), val));
+  }));
+}
+
+SmallVector<Value> getAsConstantIndexValues(OpBuilder &b, Location loc,
+                                            SmallVectorImpl<int64_t> &ints) {
+  return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
+    return b.create<arith::ConstantOp>(loc, b.getIndexAttr(val));
+  }));
+}
+
+// This is a temporary solution to deal with types that are not fully
+// supported, like list and dict. For those container types, this helper can
+// be used to convert their elements to the valid target type.
+// TODO: remove this when list gets full support.
+SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
+                                          TypeConverter *converter,
+                                          SmallVectorImpl<Value> &vs) {
+  return llvm::to_vector<4>(llvm::map_range(vs, [&](Value v) {
+    return converter->materializeTargetConversion(
+        b, loc, converter->convertType(v.getType()), v);
+  }));
+}
+
+// Converts a scalar value to the target type. The scalar value can be an
+// element from a tensor or a scalar in the PyTorch dialect. Both the scalar
+// and `dtype` should already be converted builtin types.
+Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar,
+                           Type dtype) {
+  Type scalarType = scalar.getType();
+  if (scalarType == dtype)
+    return scalar;
+
+  // TODO: For the byte (ui8) or char (i8) case, we need the unconverted dtype
+  // to be able to know if we need signed or unsigned conversion.
+  auto isByteOrChar = [](Type type) {
+    if (auto integerTy = type.dyn_cast<mlir::IntegerType>()) {
+      return integerTy.getWidth() == 8;
+    }
+    return false;
+  };
+
+  if (isByteOrChar(scalarType) || isByteOrChar(dtype) ||
+      dtype.isSignlessInteger(1)) {
+    // TODO: Handle to-boolean conversion (from-boolean conversion is handled).
+    mlir::emitError(loc)
+        << "unsupported byte, char or bool type for convertScalarToDtype "
+        << scalarType << "(scalar type) -> " << dtype << "(dtype)";
+    return nullptr;
+  }
+
+  if (auto dtypeFloat = dtype.dyn_cast<mlir::FloatType>()) {
+    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>()) {
+      if (scalarFloat.getWidth() > dtypeFloat.getWidth())
+        return b.create<arith::TruncFOp>(loc, dtype, scalar);
+      // Only scalarFloat width < dtypeFloat width can reach here.
+      return b.create<arith::ExtFOp>(loc, dtype, scalar);
+    }
+    assert(scalarType.isa<mlir::IntegerType>());
+    if (scalarType.isSignlessInteger(1))
+      return b.create<arith::UIToFPOp>(loc, dtype, scalar);
+    // It's safe to use SIToFPOp because ui8/si8 are the only ones where
+    // unsigned handling is needed, and we checked for that case above.
+    return b.create<arith::SIToFPOp>(loc, dtype, scalar);
+  }
+
+  if (auto dtypeInteger = dtype.dyn_cast<mlir::IntegerType>()) {
+    if (auto scalarFloat = scalarType.dyn_cast<mlir::FloatType>())
+      return b.create<arith::FPToSIOp>(loc, dtype, scalar);
+    assert(scalarType.isa<mlir::IntegerType>());
+    auto scalarInteger = scalarType.cast<mlir::IntegerType>();
+    if (scalarInteger.getWidth() > dtypeInteger.getWidth())
+      return b.create<arith::TruncIOp>(loc, dtype, scalar);
+    if (scalarType.isSignlessInteger(1))
+      return b.create<arith::ExtUIOp>(loc, dtype, scalar);
+    // Only scalarInteger width < dtypeInteger width can reach here.
+    // It's safe to use ExtSIOp here because ui8/si8 are the only ones where
+    // unsigned handling is needed, and we checked for that case above.
+    return b.create<arith::ExtSIOp>(loc, dtype, scalar);
+  }
+
+  llvm_unreachable("convertScalarToDtype should handle all the types");
+}
+
+} // namespace Torch
+} // namespace torch
+} // namespace mlir
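To illustrate convertScalarToDtype's cast selection, a hypothetical helper in the style of the elementwise lowerings: it normalizes two scalars to a common element type before adding them. Per the logic above, f64 to f32 emits arith.truncf, i32 to f64 emits arith.sitofp, i64 to i32 emits arith.trunci, i1 widens via arith.extui or arith.uitofp, and i8/ui8 or to-bool conversions are rejected with a nullptr result. Only `convertScalarToDtype` itself comes from this patch; `addInDtype` is a sketch:

```cpp
// Hypothetical usage sketch, not part of this diff.
#include "torch-mlir/Conversion/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;

// Adds two scalars in `resultElemTy`, converting each operand first.
// Returns nullptr when a conversion is unsupported (i8/ui8 or to-bool).
static Value addInDtype(OpBuilder &b, Location loc, Value lhs, Value rhs,
                        Type resultElemTy) {
  Value lhsCast = Torch::convertScalarToDtype(b, loc, lhs, resultElemTy);
  Value rhsCast = Torch::convertScalarToDtype(b, loc, rhs, resultElemTy);
  if (!lhsCast || !rhsCast)
    return nullptr;
  // Pick the float or integer add to match the converted element type.
  if (resultElemTy.isa<mlir::FloatType>())
    return b.create<arith::AddFOp>(loc, lhsCast, rhsCast);
  return b.create<arith::AddIOp>(loc, lhsCast, rhsCast);
}
```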