mirror of https://github.com/llvm/torch-mlir
Make `getTypeForScalarType` safer by returning `FailureOr<Type>` (#1814)
One of the potential values for a `torch_upstream::ScalarType` is `Undefined`. This means that conversion of a `ScalarType` to another type is a computation that can fail. To enforce handling of the failure case, this commit makes the two helper functions that convert `ScalarType`s into other types return `failure()` when the `ScalarType` is `Undefined`.pull/1812/head
parent
d3c6183294
commit
d849cbad14
|
@ -26,7 +26,7 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
|
||||||
std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
|
std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
|
||||||
int64_t length);
|
int64_t length);
|
||||||
torch_upstream::ScalarType getScalarTypeForType(Type type);
|
torch_upstream::ScalarType getScalarTypeForType(Type type);
|
||||||
Type getTypeForScalarType(
|
FailureOr<Type> getTypeForScalarType(
|
||||||
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
|
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
|
||||||
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
|
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
|
||||||
|
|
||||||
|
|
|
@ -127,9 +127,14 @@ public:
|
||||||
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
|
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: dtype must be a constant integer or none");
|
op, "unimplemented: dtype must be a constant integer or none");
|
||||||
resultElementType = getTypeForScalarType(
|
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
|
||||||
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
||||||
IntegerType::Signless);
|
IntegerType::Signless);
|
||||||
|
if (failed(maybeResultElementType)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unable to convert `dtypeInt` to builtin type");
|
||||||
|
}
|
||||||
|
resultElementType = *maybeResultElementType;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create an uninitialized tensor of `resultSize` shape and fill it with
|
// Create an uninitialized tensor of `resultSize` shape and fill it with
|
||||||
|
@ -227,9 +232,14 @@ public:
|
||||||
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
|
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: dtype must be a constant integer or none");
|
op, "unimplemented: dtype must be a constant integer or none");
|
||||||
resultElementType = getTypeForScalarType(
|
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
|
||||||
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
||||||
IntegerType::Signless);
|
IntegerType::Signless);
|
||||||
|
if (failed(maybeResultElementType)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unable to convert `dtypeInt` to builtin type");
|
||||||
|
}
|
||||||
|
resultElementType = *maybeResultElementType;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create an uninitialized tensor of `resultSize` shape.
|
// Create an uninitialized tensor of `resultSize` shape.
|
||||||
|
|
|
@ -33,9 +33,11 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
||||||
int64_t dtypeInt;
|
int64_t dtypeInt;
|
||||||
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||||
return false;
|
return false;
|
||||||
Type resDtype =
|
FailureOr<Type> resDtype =
|
||||||
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||||
return resDtype.isa<mlir::FloatType>();
|
if (failed(resDtype))
|
||||||
|
return false;
|
||||||
|
return resDtype->isa<mlir::FloatType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to compute the return type of the reduction function.
|
// Helper function to compute the return type of the reduction function.
|
||||||
|
|
|
@ -81,7 +81,9 @@ using namespace mlir::torch::Torch;
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|
||||||
static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
|
static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
|
||||||
return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
FailureOr<Type> result =
|
||||||
|
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||||
|
return failed(result) ? Type() : *result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
|
static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
|
||||||
|
@ -563,7 +565,9 @@ static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
|
||||||
/*skipRankCheck=*/true);
|
/*skipRankCheck=*/true);
|
||||||
state =
|
state =
|
||||||
updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
|
updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
|
||||||
return getTypeForScalarType(scalarType.getContext(), result_type(state));
|
FailureOr<Type> result =
|
||||||
|
getTypeForScalarType(scalarType.getContext(), result_type(state));
|
||||||
|
return failed(result) ? Type() : *result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<std::optional<bool>>
|
static SmallVector<std::optional<bool>>
|
||||||
|
@ -600,7 +604,8 @@ static Type getPromotedResultType(MLIRContext *context,
|
||||||
return Type();
|
return Type();
|
||||||
state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck);
|
state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck);
|
||||||
}
|
}
|
||||||
return getTypeForScalarType(context, result_type(state));
|
FailureOr<Type> result = getTypeForScalarType(context, result_type(state));
|
||||||
|
return failed(result) ? Type() : *result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Type getPromotedResultTypeAssumingNonZeroRank(
|
static Type getPromotedResultTypeAssumingNonZeroRank(
|
||||||
|
|
|
@ -46,10 +46,15 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
|
||||||
impliedTypeFromDtype = *torchType;
|
impliedTypeFromDtype = *torchType;
|
||||||
} else if (auto originalResultType =
|
} else if (auto originalResultType =
|
||||||
result.getType().dyn_cast<BaseTensorType>()) {
|
result.getType().dyn_cast<BaseTensorType>()) {
|
||||||
|
FailureOr<Type> builtinType =
|
||||||
|
getTypeForScalarType(op->getContext(), dtypeScalarType);
|
||||||
|
if (failed(builtinType)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Failed to convert `dtypeScalarType` to a builtin type");
|
||||||
|
}
|
||||||
impliedTypeFromDtype =
|
impliedTypeFromDtype =
|
||||||
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
|
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
|
||||||
originalResultType.getOptionalSizes(),
|
originalResultType.getOptionalSizes(), *builtinType);
|
||||||
getTypeForScalarType(op->getContext(), dtypeScalarType));
|
|
||||||
} else {
|
} else {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Unimplemented: Expected result type to "
|
"Unimplemented: Expected result type to "
|
||||||
|
|
|
@ -83,8 +83,9 @@ Type Torch::getTypeForTorchType(
|
||||||
llvm::report_fatal_error("unhandled type for getTypeForTorchType");
|
llvm::report_fatal_error("unhandled type for getTypeForTorchType");
|
||||||
}
|
}
|
||||||
|
|
||||||
Type Torch::getTypeForScalarType(
|
FailureOr<Type>
|
||||||
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
|
Torch::getTypeForScalarType(MLIRContext *context,
|
||||||
|
torch_upstream::ScalarType dtypeInt,
|
||||||
mlir::IntegerType::SignednessSemantics signedness) {
|
mlir::IntegerType::SignednessSemantics signedness) {
|
||||||
switch (dtypeInt) {
|
switch (dtypeInt) {
|
||||||
case torch_upstream::ScalarType::Float:
|
case torch_upstream::ScalarType::Float:
|
||||||
|
@ -110,6 +111,8 @@ Type Torch::getTypeForScalarType(
|
||||||
return mlir::ComplexType::get(Float64Type::get(context));
|
return mlir::ComplexType::get(Float64Type::get(context));
|
||||||
case torch_upstream::ScalarType::ComplexDouble:
|
case torch_upstream::ScalarType::ComplexDouble:
|
||||||
return mlir::ComplexType::get(Float128Type::get(context));
|
return mlir::ComplexType::get(Float128Type::get(context));
|
||||||
|
case torch_upstream::ScalarType::Undefined:
|
||||||
|
return failure();
|
||||||
default:
|
default:
|
||||||
llvm::report_fatal_error("unhandled type for getTypeForScalarType");
|
llvm::report_fatal_error("unhandled type for getTypeForScalarType");
|
||||||
}
|
}
|
||||||
|
@ -123,6 +126,7 @@ Torch::getTorchTypeForScalarType(MLIRContext *context,
|
||||||
return Torch::FloatType::get(context);
|
return Torch::FloatType::get(context);
|
||||||
case torch_upstream::ScalarType::Long:
|
case torch_upstream::ScalarType::Long:
|
||||||
return Torch::IntType::get(context);
|
return Torch::IntType::get(context);
|
||||||
|
case torch_upstream::ScalarType::Undefined:
|
||||||
default:
|
default:
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue