Make `getTypeForScalarType` safer by returning `FailureOr<Type>` (#1814)

One of the potential values for a `torch_upstream::ScalarType` is
`Undefined`. This means that conversion of a `ScalarType` to another
type is a computation that can fail. To enforce handling of the
failure case, this commit makes the two helper functions that convert
`ScalarType`s into other types return `failure()` when the
`ScalarType` is `Undefined`.
Ramiro Leal-Cavazos 2023-01-20 10:40:13 -08:00 committed by GitHub
parent d3c6183294
commit d849cbad14
6 changed files with 40 additions and 14 deletions
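
For readers unfamiliar with the `FailureOr` idiom, the sketch below shows the calling convention the new signature enforces. The helper `convertOrDefault` is hypothetical (it is not part of this commit), but it mirrors the pattern used at the call sites in the diffs that follow:

#include "mlir/Support/LogicalResult.h" // FailureOr, failed()
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;

// Hypothetical caller: the FailureOr<Type> result must be checked with
// failed()/succeeded() before it can be dereferenced.
static Type convertOrDefault(MLIRContext *context,
                             torch_upstream::ScalarType scalarType) {
  FailureOr<Type> maybeType = Torch::getTypeForScalarType(context, scalarType);
  if (failed(maybeType))
    return Type(); // propagate "no type" instead of crashing
  return *maybeType; // safe to dereference only after the check
}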


@@ -26,7 +26,7 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
 std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
                                                              int64_t length);
 torch_upstream::ScalarType getScalarTypeForType(Type type);
-Type getTypeForScalarType(
+FailureOr<Type> getTypeForScalarType(
     MLIRContext *context, torch_upstream::ScalarType dtypeInt,
     mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);


@@ -127,9 +127,14 @@ public:
       if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
         return rewriter.notifyMatchFailure(
             op, "unimplemented: dtype must be a constant integer or none");
-      resultElementType = getTypeForScalarType(
+      FailureOr<Type> maybeResultElementType = getTypeForScalarType(
           op->getContext(), (torch_upstream::ScalarType)dtypeInt,
           IntegerType::Signless);
+      if (failed(maybeResultElementType)) {
+        return rewriter.notifyMatchFailure(
+            op, "unable to convert `dtypeInt` to builtin type");
+      }
+      resultElementType = *maybeResultElementType;
     }

     // Create an uninitialized tensor of `resultSize` shape and fill it with
@@ -227,9 +232,14 @@ public:
       if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
         return rewriter.notifyMatchFailure(
             op, "unimplemented: dtype must be a constant integer or none");
-      resultElementType = getTypeForScalarType(
+      FailureOr<Type> maybeResultElementType = getTypeForScalarType(
           op->getContext(), (torch_upstream::ScalarType)dtypeInt,
           IntegerType::Signless);
+      if (failed(maybeResultElementType)) {
+        return rewriter.notifyMatchFailure(
+            op, "unable to convert `dtypeInt` to builtin type");
+      }
+      resultElementType = *maybeResultElementType;
     }

     // Create an uninitialized tensor of `resultSize` shape.


@@ -33,9 +33,11 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
   int64_t dtypeInt;
   if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
     return false;
-  Type resDtype =
+  FailureOr<Type> resDtype =
       getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
-  return resDtype.isa<mlir::FloatType>();
+  if (failed(resDtype))
+    return false;
+  return resDtype->isa<mlir::FloatType>();
 }

 // Helper function to compute the return type of the reduction function.


@@ -81,7 +81,9 @@ using namespace mlir::torch::Torch;
 // -----------------------------------------------------------------------------

 static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
-  return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
+  FailureOr<Type> result =
+      getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
+  return failed(result) ? Type() : *result;
 }

 static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
@@ -563,7 +565,9 @@ static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
                             /*skipRankCheck=*/true);
   state =
       updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
-  return getTypeForScalarType(scalarType.getContext(), result_type(state));
+  FailureOr<Type> result =
+      getTypeForScalarType(scalarType.getContext(), result_type(state));
+  return failed(result) ? Type() : *result;
 }

 static SmallVector<std::optional<bool>>
@@ -600,7 +604,8 @@ static Type getPromotedResultType(MLIRContext *context,
       return Type();
     state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck);
   }
-  return getTypeForScalarType(context, result_type(state));
+  FailureOr<Type> result = getTypeForScalarType(context, result_type(state));
+  return failed(result) ? Type() : *result;
 }

 static Type getPromotedResultTypeAssumingNonZeroRank(
static Type getPromotedResultTypeAssumingNonZeroRank( static Type getPromotedResultTypeAssumingNonZeroRank(


@@ -46,10 +46,15 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
     impliedTypeFromDtype = *torchType;
   } else if (auto originalResultType =
                  result.getType().dyn_cast<BaseTensorType>()) {
+    FailureOr<Type> builtinType =
+        getTypeForScalarType(op->getContext(), dtypeScalarType);
+    if (failed(builtinType)) {
+      return rewriter.notifyMatchFailure(
+          op, "Failed to convert `dtypeScalarType` to a builtin type");
+    }
     impliedTypeFromDtype =
         originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
-            originalResultType.getOptionalSizes(),
-            getTypeForScalarType(op->getContext(), dtypeScalarType));
+            originalResultType.getOptionalSizes(), *builtinType);
   } else {
     return rewriter.notifyMatchFailure(op,
                                        "Unimplemented: Expected result type to "


@@ -83,9 +83,10 @@ Type Torch::getTypeForTorchType(
   llvm::report_fatal_error("unhandled type for getTypeForTorchType");
 }

-Type Torch::getTypeForScalarType(
-    MLIRContext *context, torch_upstream::ScalarType dtypeInt,
-    mlir::IntegerType::SignednessSemantics signedness) {
+FailureOr<Type>
+Torch::getTypeForScalarType(MLIRContext *context,
+                            torch_upstream::ScalarType dtypeInt,
+                            mlir::IntegerType::SignednessSemantics signedness) {
   switch (dtypeInt) {
   case torch_upstream::ScalarType::Float:
     return Float32Type::get(context);
@@ -110,6 +111,8 @@ Type Torch::getTypeForScalarType(
     return mlir::ComplexType::get(Float64Type::get(context));
   case torch_upstream::ScalarType::ComplexDouble:
     return mlir::ComplexType::get(Float128Type::get(context));
+  case torch_upstream::ScalarType::Undefined:
+    return failure();
   default:
     llvm::report_fatal_error("unhandled type for getTypeForScalarType");
   }
@@ -123,6 +126,7 @@ Torch::getTorchTypeForScalarType(MLIRContext *context,
     return Torch::FloatType::get(context);
   case torch_upstream::ScalarType::Long:
     return Torch::IntType::get(context);
+  case torch_upstream::ScalarType::Undefined:
   default:
     return failure();
   }