build: update llvm tag to 3a020527 (#1717)

Summary of changes:

 - Replace `llvm::None` with `std::nullopt`, since the former is deprecated
   (https://reviews.llvm.org/D139763)

 - Use setter for symbol visibility instead of passing string attribute when
   creating FuncOp
pull/1718/head oneshot-20221214.73
Ashay Rane 2022-12-14 02:06:39 -06:00 committed by GitHub
parent b1f6832849
commit f63bb9f86c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 57 additions and 56 deletions

@ -1 +1 @@
Subproject commit 798fa4b415eea55c868ae42b874083cb9886991e Subproject commit 3a020527c2af10741b12e756de45bd6f774885a4

2
externals/mlir-hlo vendored

@ -1 +1 @@
Subproject commit 037315c6515b5323ff78bc3c54d70dffad2ddbd0 Subproject commit 8df20065b22be628f2d365c387200df7d02b80c1

View File

@ -84,7 +84,7 @@ SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
// should be converted builtin types. // should be converted builtin types.
Value convertScalarToDtype( Value convertScalarToDtype(
OpBuilder &b, Location loc, Value scalar, Type dtype, OpBuilder &b, Location loc, Value scalar, Type dtype,
llvm::Optional<Type> srcOriginalDtype = llvm::None); llvm::Optional<Type> srcOriginalDtype = std::nullopt);
} // namespace Torch } // namespace Torch
} // namespace torch } // namespace torch

View File

@ -190,7 +190,7 @@ struct torch_list_of_optional_constant_ints_op_binder {
if (matchPattern(value, m_TorchConstantInt(&num))) if (matchPattern(value, m_TorchConstantInt(&num)))
bind_values.push_back(num); bind_values.push_back(num);
else if (value.getType().isa<Torch::NoneType>()) else if (value.getType().isa<Torch::NoneType>())
bind_values.push_back(llvm::None); bind_values.push_back(std::nullopt);
else else
return false; return false;
} }

View File

@ -198,7 +198,7 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context,
intptr_t numSizes, intptr_t numSizes,
const int64_t *optionalSizes, const int64_t *optionalSizes,
MlirType optionalDtype) { MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None; Optional<ArrayRef<int64_t>> optionalSizesArrayRef = std::nullopt;
// if numSizes == -1, then it is unranked. // if numSizes == -1, then it is unranked.
if (numSizes > -1) if (numSizes > -1)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
@ -232,7 +232,7 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context,
intptr_t numSizes, intptr_t numSizes,
const int64_t *optionalSizes, const int64_t *optionalSizes,
MlirType optionalDtype) { MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None; Optional<ArrayRef<int64_t>> optionalSizesArrayRef = std::nullopt;
// if numSizes == -1, then it is unranked. // if numSizes == -1, then it is unranked.
if (numSizes > -1) if (numSizes > -1)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);

View File

@ -37,7 +37,7 @@ llvm::Optional<Value> convertReduceOpCommon(
RankedTensorType input_type = RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>(); input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type) if (!input_type)
return llvm::None; return std::nullopt;
ArrayRef<int64_t> input_shape = input_type.getShape(); ArrayRef<int64_t> input_shape = input_type.getShape();
ArrayRef<int64_t> output_shape = output_type.getShape(); ArrayRef<int64_t> output_shape = output_type.getShape();
@ -101,7 +101,7 @@ convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType input_type = RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>(); input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type) if (!input_type)
return llvm::None; return std::nullopt;
return convertReduceOpCommon<tosa::ReduceAllOp>( return convertReduceOpCommon<tosa::ReduceAllOp>(
rewriter, op, output_type, input_value, axes_elems, keep_dims, rewriter, op, output_type, input_value, axes_elems, keep_dims,
@ -116,7 +116,7 @@ convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType input_type = RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>(); input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type) if (!input_type)
return llvm::None; return std::nullopt;
return convertReduceOpCommon<tosa::ReduceAnyOp>( return convertReduceOpCommon<tosa::ReduceAnyOp>(
rewriter, op, output_type, input_value, axes_elems, keep_dims, rewriter, op, output_type, input_value, axes_elems, keep_dims,
@ -131,7 +131,7 @@ convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType input_type = RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>(); input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type) if (!input_type)
return llvm::None; return std::nullopt;
return convertReduceOpCommon<tosa::ReduceMinOp>( return convertReduceOpCommon<tosa::ReduceMinOp>(
rewriter, op, output_type, input_value, axes_elems, keep_dims, rewriter, op, output_type, input_value, axes_elems, keep_dims,
@ -146,7 +146,7 @@ convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType input_type = RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>(); input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type) if (!input_type)
return llvm::None; return std::nullopt;
return convertReduceOpCommon<tosa::ReduceMaxOp>( return convertReduceOpCommon<tosa::ReduceMaxOp>(
rewriter, op, output_type, input_value, axes_elems, keep_dims, rewriter, op, output_type, input_value, axes_elems, keep_dims,
@ -161,7 +161,7 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType input_type = RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>(); input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type) if (!input_type)
return llvm::None; return std::nullopt;
bool input_is_qtype = bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>(); input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
@ -171,7 +171,7 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
if (input_is_qtype || output_is_qtype) { if (input_is_qtype || output_is_qtype) {
op->emitOpError("ConvertReduceProdOp: input/output tensor should " op->emitOpError("ConvertReduceProdOp: input/output tensor should "
"be all floating-point."); "be all floating-point.");
return llvm::None; return std::nullopt;
} }
return convertReduceOpCommon<tosa::ReduceProdOp>( return convertReduceOpCommon<tosa::ReduceProdOp>(
@ -187,7 +187,7 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType input_type = RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>(); input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type) if (!input_type)
return llvm::None; return std::nullopt;
bool input_is_qtype = bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>(); input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
@ -197,7 +197,7 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
if (input_is_qtype != output_is_qtype) { if (input_is_qtype != output_is_qtype) {
op->emitOpError("ConvertReduceSumOp: input/output tensor should " op->emitOpError("ConvertReduceSumOp: input/output tensor should "
"be all quantized or all floating-point."); "be all quantized or all floating-point.");
return llvm::None; return std::nullopt;
} }
double input_scale = 1.0f; double input_scale = 1.0f;
@ -242,7 +242,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
RankedTensorType input_type = RankedTensorType input_type =
input_value.getType().dyn_cast<RankedTensorType>(); input_value.getType().dyn_cast<RankedTensorType>();
if (!input_type) if (!input_type)
return llvm::None; return std::nullopt;
bool input_is_qtype = bool input_is_qtype =
input_type.getElementType().isa<mlir::quant::UniformQuantizedType>(); input_type.getElementType().isa<mlir::quant::UniformQuantizedType>();
@ -252,7 +252,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
if (input_is_qtype != output_is_qtype) { if (input_is_qtype != output_is_qtype) {
op->emitOpError("ConvertReduceSumOp: input/output tensor should " op->emitOpError("ConvertReduceSumOp: input/output tensor should "
"be all quantized or all floating-point."); "be all quantized or all floating-point.");
return llvm::None; return std::nullopt;
} }
// Only supports float type mean() if it's non-quantized // Only supports float type mean() if it's non-quantized
@ -260,7 +260,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
op->emitWarning( op->emitWarning(
"Failed convertReduceMean: input unquantized type but output element " "Failed convertReduceMean: input unquantized type but output element "
"not FloatType!"); "not FloatType!");
return llvm::None; return std::nullopt;
} }
int64_t input_rank = input_type.getRank(); int64_t input_rank = input_type.getRank();
@ -303,7 +303,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
output_zp); output_zp);
if (!val.has_value()) if (!val.has_value())
return llvm::None; return std::nullopt;
if (!input_is_qtype) { if (!input_is_qtype) {
Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale); Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale);

View File

@ -162,7 +162,7 @@ llvm::Optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
if (vec.size() != num_total_elements) { if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch."); op->emitOpError("getConstTensor(): number of elements mismatch.");
return llvm::None; return std::nullopt;
} }
auto const_type = auto const_type =
@ -186,7 +186,7 @@ llvm::Optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
if (vec.size() != num_total_elements) { if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch."); op->emitOpError("getConstTensor(): number of elements mismatch.");
return llvm::None; return std::nullopt;
} }
auto const_type = RankedTensorType::get( auto const_type = RankedTensorType::get(
@ -210,7 +210,7 @@ llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
if (vec.size() != num_total_elements) { if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch."); op->emitOpError("getConstTensor(): number of elements mismatch.");
return llvm::None; return std::nullopt;
} }
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type()); auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());

View File

@ -1028,7 +1028,7 @@ traceKnownSizeTensorType(Value value, llvm::Optional<int64_t> dim) {
if (!tensorType.hasSizes()) if (!tensorType.hasSizes())
return false; return false;
if (dim == llvm::None) if (dim == std::nullopt)
return tensorType.areAllSizesKnown(); return tensorType.areAllSizesKnown();
// If the dimension value is negative, then convert it to a positive value. // If the dimension value is negative, then convert it to a positive value.
@ -1062,7 +1062,7 @@ traceKnownSizeTensorType(Value value, llvm::Optional<int64_t> dim) {
void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { MLIRContext *context) {
patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) { patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) {
auto type = traceKnownSizeTensorType(op.getOperand(), llvm::None); auto type = traceKnownSizeTensorType(op.getOperand(), std::nullopt);
if (failed(type)) if (failed(type))
return rewriter.notifyMatchFailure(op, "all sizes not known"); return rewriter.notifyMatchFailure(op, "all sizes not known");
SmallVector<Value> listElements; SmallVector<Value> listElements;

View File

@ -89,7 +89,7 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
static Optional<SmallVector<Type>> static Optional<SmallVector<Type>>
parseMultipleContainedTypes(AsmParser &parser) { parseMultipleContainedTypes(AsmParser &parser) {
if (parser.parseLess()) if (parser.parseLess())
return None; return std::nullopt;
SmallVector<Type> containedTypes; SmallVector<Type> containedTypes;
if (!parser.parseOptionalGreater()) if (!parser.parseOptionalGreater())
@ -97,11 +97,11 @@ parseMultipleContainedTypes(AsmParser &parser) {
do { do {
Type containedType = parseTorchDialectType(parser); Type containedType = parseTorchDialectType(parser);
if (!containedType) if (!containedType)
return None; return std::nullopt;
containedTypes.push_back(containedType); containedTypes.push_back(containedType);
} while (!parser.parseOptionalComma()); } while (!parser.parseOptionalComma());
if (parser.parseGreater()) if (parser.parseGreater())
return None; return std::nullopt;
return containedTypes; return containedTypes;
} }
@ -222,7 +222,8 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser,
llvm::SMLoc startLoc = parser.getCurrentLocation(); llvm::SMLoc startLoc = parser.getCurrentLocation();
if (parser.parseOptionalLess()) if (parser.parseOptionalLess())
return getTensorType(context, return getTensorType(context,
/*optionalSizes=*/None, /*optionalDtype=*/Type()); /*optionalSizes=*/std::nullopt,
/*optionalDtype=*/Type());
bool hasSizes; bool hasSizes;
SmallVector<int64_t> sizes; SmallVector<int64_t> sizes;
if (succeeded(parser.parseOptionalStar())) { if (succeeded(parser.parseOptionalStar())) {
@ -320,7 +321,7 @@ ValueTensorType NonValueTensorType::getWithValueSemantics() const {
NonValueTensorType NonValueTensorType
NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
return NonValueTensorType::get(context, return NonValueTensorType::get(context,
/*optionalSizes=*/None, /*optionalSizes=*/std::nullopt,
/*optionalDtype=*/Type()); /*optionalDtype=*/Type());
} }
@ -357,7 +358,7 @@ NonValueTensorType ValueTensorType::getWithoutValueSemantics() const {
ValueTensorType ValueTensorType
ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
return ValueTensorType::get(context, return ValueTensorType::get(context,
/*optionalSizes=*/None, /*optionalSizes=*/std::nullopt,
/*optionalDtype=*/Type()); /*optionalDtype=*/Type());
} }
@ -428,8 +429,8 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
// If neither has sizes, we have nothing left to do. // If neither has sizes, we have nothing left to do.
if (!lhs.hasSizes() && !rhs.hasSizes()) { if (!lhs.hasSizes() && !rhs.hasSizes()) {
return ValueTensorType::get(lhs.getContext(), /*optionalSizes=*/None, return ValueTensorType::get(lhs.getContext(),
dtype); /*optionalSizes=*/std::nullopt, dtype);
} }
// If the number of sizes is different, the two types are contradictory. // If the number of sizes is different, the two types are contradictory.

View File

@ -85,7 +85,7 @@ public:
func::FuncOp methodFunc) { func::FuncOp methodFunc) {
auto it = funcLinkageInfo.find({instance, methodFunc}); auto it = funcLinkageInfo.find({instance, methodFunc});
if (it == funcLinkageInfo.end()) if (it == funcLinkageInfo.end())
return None; return std::nullopt;
return it->second; return it->second;
} }
@ -638,7 +638,7 @@ static LogicalResult globalizeObjectGraph(ModuleOp module) {
for (auto &monomorphization : tracker.getMonomorphizations()) { for (auto &monomorphization : tracker.getMonomorphizations()) {
auto newFunc = cast<func::FuncOp>(monomorphization.func->clone()); auto newFunc = cast<func::FuncOp>(monomorphization.func->clone());
newFuncs[monomorphization] = newFunc; newFuncs[monomorphization] = newFunc;
Optional<LinkageInfo> linkageInfo = None; Optional<LinkageInfo> linkageInfo = std::nullopt;
// If it is potentially a method, check its linkage info. // If it is potentially a method, check its linkage info.
if (monomorphization.argInstances.size() != 0 && if (monomorphization.argInstances.size() != 0 &&
monomorphization.argInstances[0].argIndex == 0) { monomorphization.argInstances[0].argIndex == 0) {

View File

@ -112,9 +112,9 @@ static torch_upstream::TypeKind getTypeKind(Type type) {
} }
/// Returns the dtype that assumes information from both `lhs` and `rhs`. /// Returns the dtype that assumes information from both `lhs` and `rhs`.
/// Returns `None` if the types are contradictory. Note this can only be used /// Returns `std::nullopt` if the types are contradictory. Note this can only
/// on the `dtype` from tensors and can't be used on other types like scalar /// be used on the `dtype` from tensors and can't be used on other types like
/// types. /// scalar types.
static Optional<Type> meetElementTypes(Type lhs, Type rhs) { static Optional<Type> meetElementTypes(Type lhs, Type rhs) {
auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); }; auto isNullOrBuiltIn = [](Type type) { return !type || isBuiltInType(type); };
(void)isNullOrBuiltIn; (void)isNullOrBuiltIn;
@ -127,7 +127,7 @@ static Optional<Type> meetElementTypes(Type lhs, Type rhs) {
return lhs; return lhs;
if (lhs == rhs) if (lhs == rhs)
return lhs; return lhs;
return None; return std::nullopt;
} }
enum class OptionalKnowledge { enum class OptionalKnowledge {
@ -137,7 +137,7 @@ enum class OptionalKnowledge {
}; };
/// Returns the OptionalKnowledge that assumes information from both `lhs` and /// Returns the OptionalKnowledge that assumes information from both `lhs` and
/// `rhs`. Returns `None` if the knowledges are contradictory. /// `rhs`. Returns `std::nullopt` if the knowledges are contradictory.
static Optional<OptionalKnowledge> static Optional<OptionalKnowledge>
meetOptionalKnowledge(OptionalKnowledge lhs, OptionalKnowledge rhs) { meetOptionalKnowledge(OptionalKnowledge lhs, OptionalKnowledge rhs) {
if (lhs == OptionalKnowledge::unKnown) if (lhs == OptionalKnowledge::unKnown)
@ -146,7 +146,7 @@ meetOptionalKnowledge(OptionalKnowledge lhs, OptionalKnowledge rhs) {
return lhs; return lhs;
if (lhs == rhs) if (lhs == rhs)
return lhs; return lhs;
return None; return std::nullopt;
} }
static OptionalKnowledge joinOptionalKnowledge(OptionalKnowledge lhs, static OptionalKnowledge joinOptionalKnowledge(OptionalKnowledge lhs,
@ -327,7 +327,7 @@ struct ValueKnowledge {
// Given two pieces of static knowledge, calculate new knowledge that assumes // Given two pieces of static knowledge, calculate new knowledge that assumes
// the facts from both. // the facts from both.
// If the two pieces of knowledge are contradictory, None is returned. // If the two pieces of knowledge are contradictory, std::nullopt is returned.
static Optional<ValueKnowledge> meet(const ValueKnowledge &lhs, static Optional<ValueKnowledge> meet(const ValueKnowledge &lhs,
const ValueKnowledge &rhs) { const ValueKnowledge &rhs) {
if (!lhs.isInitialized) if (!lhs.isInitialized)
@ -338,13 +338,13 @@ struct ValueKnowledge {
Optional<ValueKnowledge> knowledge = meetTypes(lhs, rhs); Optional<ValueKnowledge> knowledge = meetTypes(lhs, rhs);
if (!knowledge.has_value()) if (!knowledge.has_value())
return None; return std::nullopt;
ValueKnowledge result = knowledge.value(); ValueKnowledge result = knowledge.value();
Optional<OptionalKnowledge> optional = Optional<OptionalKnowledge> optional =
meetOptionalKnowledge(lhs.optional, rhs.optional); meetOptionalKnowledge(lhs.optional, rhs.optional);
if (!optional.has_value()) if (!optional.has_value())
return None; return std::nullopt;
result.optional = optional.value(); result.optional = optional.value();
return result; return result;
} }
@ -362,7 +362,7 @@ struct ValueKnowledge {
return rhs; return rhs;
if (lhs == rhs) if (lhs == rhs)
return lhs; return lhs;
return None; return std::nullopt;
} }
// We start in the uninitialized state by default. // We start in the uninitialized state by default.
@ -559,7 +559,7 @@ static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
torch_upstream::ResultTypeState state = {}; torch_upstream::ResultTypeState state = {};
// No need to check if rank is zero for tensor because scalar uses // No need to check if rank is zero for tensor because scalar uses
// wrappedResult which is a lower priority than both dimResult and zeroResult. // wrappedResult which is a lower priority than both dimResult and zeroResult.
state = updateResultTypeState(tensor, /*rankIsNonZero=*/None, state, state = updateResultTypeState(tensor, /*rankIsNonZero=*/std::nullopt, state,
/*skipRankCheck=*/true); /*skipRankCheck=*/true);
state = state =
updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state); updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
@ -573,7 +573,7 @@ static SmallVector<Optional<bool>> getRankIsNonZeroArray(ValueRange values) {
if (tensorType.hasSizes()) { if (tensorType.hasSizes()) {
rankIsNonZero.push_back(tensorType.getSizes().size() != 0); rankIsNonZero.push_back(tensorType.getSizes().size() != 0);
} else { } else {
rankIsNonZero.push_back(None); rankIsNonZero.push_back(std::nullopt);
} }
} }
} }

View File

@ -27,10 +27,10 @@ llvm::Optional<int64_t>
Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) { Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) {
int64_t dim; int64_t dim;
if (!matchPattern(v, m_TorchConstantInt(&dim))) if (!matchPattern(v, m_TorchConstantInt(&dim)))
return llvm::None; return std::nullopt;
dim = toPositiveDim(dim, length); dim = toPositiveDim(dim, length);
if (!isValidDim(dim, length)) if (!isValidDim(dim, length))
return llvm::None; return std::nullopt;
return dim; return dim;
} }
@ -169,7 +169,7 @@ bool Torch::isBuiltInType(Type type) {
Optional<unsigned> Torch::getTensorRank(Value tensor) { Optional<unsigned> Torch::getTensorRank(Value tensor) {
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>(); BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
if (!tensorType.hasSizes()) if (!tensorType.hasSizes())
return llvm::None; return std::nullopt;
return tensorType.getSizes().size(); return tensorType.getSizes().size();
} }

View File

@ -61,7 +61,7 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target,
Location loc) -> Optional<Value> { Location loc) -> Optional<Value> {
// Other builtin integer types could be handled by other materializers. // Other builtin integer types could be handled by other materializers.
if (!(type.getWidth() == 1 && type.isSignless())) if (!(type.getWidth() == 1 && type.isSignless()))
return None; return std::nullopt;
assert(inputs.size() == 1); assert(inputs.size() == 1);
assert(inputs[0].getType().isa<Torch::BoolType>()); assert(inputs[0].getType().isa<Torch::BoolType>());
return builder.create<ToI1Op>(loc, inputs[0]).getResult(); return builder.create<ToI1Op>(loc, inputs[0]).getResult();
@ -87,11 +87,11 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
Location loc) -> Optional<Value> { Location loc) -> Optional<Value> {
// Other builtin integer types could be handled by other materializers. // Other builtin integer types could be handled by other materializers.
if (!(type.getWidth() == 64 && type.isSignless())) if (!(type.getWidth() == 64 && type.isSignless()))
return None; return std::nullopt;
// Other input type to be converted to i64 are handled by other // Other input type to be converted to i64 are handled by other
// materializers. // materializers.
if (!inputs[0].getType().isa<Torch::IntType>()) if (!inputs[0].getType().isa<Torch::IntType>())
return None; return std::nullopt;
assert(inputs.size() == 1); assert(inputs.size() == 1);
return builder.create<ToI64Op>(loc, inputs[0]).getResult(); return builder.create<ToI64Op>(loc, inputs[0]).getResult();
}); });
@ -140,11 +140,11 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
Location loc) -> Optional<Value> { Location loc) -> Optional<Value> {
// Other builtin integer types could be handled by other materializers. // Other builtin integer types could be handled by other materializers.
if (!(type.getWidth() == 64 && type.isSignless())) if (!(type.getWidth() == 64 && type.isSignless()))
return None; return std::nullopt;
// Other input type to be converted to i64 are handled by other // Other input type to be converted to i64 are handled by other
// materializers. // materializers.
if (!inputs[0].getType().isa<Torch::GeneratorType>()) if (!inputs[0].getType().isa<Torch::GeneratorType>())
return None; return std::nullopt;
assert(inputs.size() == 1); assert(inputs.size() == 1);
return builder.create<GeneratorToI64Op>(loc, inputs[0]).getResult(); return builder.create<GeneratorToI64Op>(loc, inputs[0]).getResult();
}); });

View File

@ -206,8 +206,8 @@ class MungeCallingConventions
for (auto &p : invokedConsumeFuncReturnFuncs) { for (auto &p : invokedConsumeFuncReturnFuncs) {
auto consumeFuncReturnFunc = b.create<func::FuncOp>( auto consumeFuncReturnFunc = b.create<func::FuncOp>(
module.getLoc(), p.first, module.getLoc(), p.first,
FunctionType::get(module.getContext(), p.second, {}), FunctionType::get(module.getContext(), p.second, {}));
b.getStringAttr("private")); consumeFuncReturnFunc.setPrivate();
addEmitCInterfaceAttr(consumeFuncReturnFunc); addEmitCInterfaceAttr(consumeFuncReturnFunc);
} }
} }