Make `getTypeForScalarType` safer by returning `FailureOr<Type>` (#1814)

One of the potential values for a `torch_upstream::ScalarType` is
`Undefined`. This means that conversion of a `ScalarType` to another
type is a computation that can fail. To enforce handling of the
failure case, this commit makes the two helper functions that convert
`ScalarType`s into other types return `failure()` when the
`ScalarType` is `Undefined`.
pull/1812/head
Ramiro Leal-Cavazos 2023-01-20 10:40:13 -08:00 committed by GitHub
parent d3c6183294
commit d849cbad14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 40 additions and 14 deletions

View File

@ -26,7 +26,7 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
int64_t length);
torch_upstream::ScalarType getScalarTypeForType(Type type);
Type getTypeForScalarType(
FailureOr<Type> getTypeForScalarType(
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);

View File

@ -127,9 +127,14 @@ public:
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dtype must be a constant integer or none");
resultElementType = getTypeForScalarType(
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
IntegerType::Signless);
if (failed(maybeResultElementType)) {
return rewriter.notifyMatchFailure(
op, "unable to convert `dtypeInt` to builtin type");
}
resultElementType = *maybeResultElementType;
}
// Create an uninitialized tensor of `resultSize` shape and fill it with
@ -227,9 +232,14 @@ public:
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
return rewriter.notifyMatchFailure(
op, "unimplemented: dtype must be a constant integer or none");
resultElementType = getTypeForScalarType(
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
IntegerType::Signless);
if (failed(maybeResultElementType)) {
return rewriter.notifyMatchFailure(
op, "unable to convert `dtypeInt` to builtin type");
}
resultElementType = *maybeResultElementType;
}
// Create an uninitialized tensor of `resultSize` shape.

View File

@ -33,9 +33,11 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
int64_t dtypeInt;
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
return false;
Type resDtype =
FailureOr<Type> resDtype =
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
return resDtype.isa<mlir::FloatType>();
if (failed(resDtype))
return false;
return resDtype->isa<mlir::FloatType>();
}
// Helper function to compute the return type of the reduction function.

View File

@ -81,7 +81,9 @@ using namespace mlir::torch::Torch;
// -----------------------------------------------------------------------------
static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
FailureOr<Type> result =
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
return failed(result) ? Type() : *result;
}
static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
@ -563,7 +565,9 @@ static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
/*skipRankCheck=*/true);
state =
updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
return getTypeForScalarType(scalarType.getContext(), result_type(state));
FailureOr<Type> result =
getTypeForScalarType(scalarType.getContext(), result_type(state));
return failed(result) ? Type() : *result;
}
static SmallVector<std::optional<bool>>
@ -600,7 +604,8 @@ static Type getPromotedResultType(MLIRContext *context,
return Type();
state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck);
}
return getTypeForScalarType(context, result_type(state));
FailureOr<Type> result = getTypeForScalarType(context, result_type(state));
return failed(result) ? Type() : *result;
}
static Type getPromotedResultTypeAssumingNonZeroRank(

View File

@ -46,10 +46,15 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
impliedTypeFromDtype = *torchType;
} else if (auto originalResultType =
result.getType().dyn_cast<BaseTensorType>()) {
FailureOr<Type> builtinType =
getTypeForScalarType(op->getContext(), dtypeScalarType);
if (failed(builtinType)) {
return rewriter.notifyMatchFailure(
op, "Failed to convert `dtypeScalarType` to a builtin type");
}
impliedTypeFromDtype =
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
originalResultType.getOptionalSizes(),
getTypeForScalarType(op->getContext(), dtypeScalarType));
originalResultType.getOptionalSizes(), *builtinType);
} else {
return rewriter.notifyMatchFailure(op,
"Unimplemented: Expected result type to "

View File

@ -83,8 +83,9 @@ Type Torch::getTypeForTorchType(
llvm::report_fatal_error("unhandled type for getTypeForTorchType");
}
Type Torch::getTypeForScalarType(
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
FailureOr<Type>
Torch::getTypeForScalarType(MLIRContext *context,
torch_upstream::ScalarType dtypeInt,
mlir::IntegerType::SignednessSemantics signedness) {
switch (dtypeInt) {
case torch_upstream::ScalarType::Float:
@ -110,6 +111,8 @@ Type Torch::getTypeForScalarType(
return mlir::ComplexType::get(Float64Type::get(context));
case torch_upstream::ScalarType::ComplexDouble:
return mlir::ComplexType::get(Float128Type::get(context));
case torch_upstream::ScalarType::Undefined:
return failure();
default:
llvm::report_fatal_error("unhandled type for getTypeForScalarType");
}
@ -123,6 +126,7 @@ Torch::getTorchTypeForScalarType(MLIRContext *context,
return Torch::FloatType::get(context);
case torch_upstream::ScalarType::Long:
return Torch::IntType::get(context);
case torch_upstream::ScalarType::Undefined:
default:
return failure();
}