mirror of https://github.com/llvm/torch-mlir
Make `getTypeForScalarType` safer by returning `FailureOr<Type>` (#1814)
One of the potential values for a `torch_upstream::ScalarType` is `Undefined`. This means that conversion of a `ScalarType` to another type is a computation that can fail. To enforce handling of the failure case, this commit makes the two helper functions that convert `ScalarType`s into other types return `failure()` when the `ScalarType` is `Undefined`.pull/1812/head
parent
d3c6183294
commit
d849cbad14
|
@ -26,7 +26,7 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
|
|||
std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
|
||||
int64_t length);
|
||||
torch_upstream::ScalarType getScalarTypeForType(Type type);
|
||||
Type getTypeForScalarType(
|
||||
FailureOr<Type> getTypeForScalarType(
|
||||
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
|
||||
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
|
||||
|
||||
|
|
|
@ -127,9 +127,14 @@ public:
|
|||
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dtype must be a constant integer or none");
|
||||
resultElementType = getTypeForScalarType(
|
||||
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
|
||||
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
||||
IntegerType::Signless);
|
||||
if (failed(maybeResultElementType)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to convert `dtypeInt` to builtin type");
|
||||
}
|
||||
resultElementType = *maybeResultElementType;
|
||||
}
|
||||
|
||||
// Create an uninitialized tensor of `resultSize` shape and fill it with
|
||||
|
@ -227,9 +232,14 @@ public:
|
|||
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dtype must be a constant integer or none");
|
||||
resultElementType = getTypeForScalarType(
|
||||
FailureOr<Type> maybeResultElementType = getTypeForScalarType(
|
||||
op->getContext(), (torch_upstream::ScalarType)dtypeInt,
|
||||
IntegerType::Signless);
|
||||
if (failed(maybeResultElementType)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to convert `dtypeInt` to builtin type");
|
||||
}
|
||||
resultElementType = *maybeResultElementType;
|
||||
}
|
||||
|
||||
// Create an uninitialized tensor of `resultSize` shape.
|
||||
|
|
|
@ -33,9 +33,11 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
|||
int64_t dtypeInt;
|
||||
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||
return false;
|
||||
Type resDtype =
|
||||
FailureOr<Type> resDtype =
|
||||
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||
return resDtype.isa<mlir::FloatType>();
|
||||
if (failed(resDtype))
|
||||
return false;
|
||||
return resDtype->isa<mlir::FloatType>();
|
||||
}
|
||||
|
||||
// Helper function to compute the return type of the reduction function.
|
||||
|
@ -3803,4 +3805,4 @@ std::unique_ptr<OperationPass<func::FuncOp>>
|
|||
mlir::torch::Torch::createDecomposeComplexOpsPass(
|
||||
ArrayRef<std::string> legalOps) {
|
||||
return std::make_unique<DecomposeComplexOpsPass>(legalOps);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -81,7 +81,9 @@ using namespace mlir::torch::Torch;
|
|||
// -----------------------------------------------------------------------------
|
||||
|
||||
static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
|
||||
return getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||
FailureOr<Type> result =
|
||||
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||
return failed(result) ? Type() : *result;
|
||||
}
|
||||
|
||||
static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
|
||||
|
@ -563,7 +565,9 @@ static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
|
|||
/*skipRankCheck=*/true);
|
||||
state =
|
||||
updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
|
||||
return getTypeForScalarType(scalarType.getContext(), result_type(state));
|
||||
FailureOr<Type> result =
|
||||
getTypeForScalarType(scalarType.getContext(), result_type(state));
|
||||
return failed(result) ? Type() : *result;
|
||||
}
|
||||
|
||||
static SmallVector<std::optional<bool>>
|
||||
|
@ -600,7 +604,8 @@ static Type getPromotedResultType(MLIRContext *context,
|
|||
return Type();
|
||||
state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck);
|
||||
}
|
||||
return getTypeForScalarType(context, result_type(state));
|
||||
FailureOr<Type> result = getTypeForScalarType(context, result_type(state));
|
||||
return failed(result) ? Type() : *result;
|
||||
}
|
||||
|
||||
static Type getPromotedResultTypeAssumingNonZeroRank(
|
||||
|
|
|
@ -46,10 +46,15 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
|
|||
impliedTypeFromDtype = *torchType;
|
||||
} else if (auto originalResultType =
|
||||
result.getType().dyn_cast<BaseTensorType>()) {
|
||||
FailureOr<Type> builtinType =
|
||||
getTypeForScalarType(op->getContext(), dtypeScalarType);
|
||||
if (failed(builtinType)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Failed to convert `dtypeScalarType` to a builtin type");
|
||||
}
|
||||
impliedTypeFromDtype =
|
||||
originalResultType.cast<BaseTensorType>().getWithSizesAndDtype(
|
||||
originalResultType.getOptionalSizes(),
|
||||
getTypeForScalarType(op->getContext(), dtypeScalarType));
|
||||
originalResultType.getOptionalSizes(), *builtinType);
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Unimplemented: Expected result type to "
|
||||
|
|
|
@ -83,9 +83,10 @@ Type Torch::getTypeForTorchType(
|
|||
llvm::report_fatal_error("unhandled type for getTypeForTorchType");
|
||||
}
|
||||
|
||||
Type Torch::getTypeForScalarType(
|
||||
MLIRContext *context, torch_upstream::ScalarType dtypeInt,
|
||||
mlir::IntegerType::SignednessSemantics signedness) {
|
||||
FailureOr<Type>
|
||||
Torch::getTypeForScalarType(MLIRContext *context,
|
||||
torch_upstream::ScalarType dtypeInt,
|
||||
mlir::IntegerType::SignednessSemantics signedness) {
|
||||
switch (dtypeInt) {
|
||||
case torch_upstream::ScalarType::Float:
|
||||
return Float32Type::get(context);
|
||||
|
@ -110,6 +111,8 @@ Type Torch::getTypeForScalarType(
|
|||
return mlir::ComplexType::get(Float64Type::get(context));
|
||||
case torch_upstream::ScalarType::ComplexDouble:
|
||||
return mlir::ComplexType::get(Float128Type::get(context));
|
||||
case torch_upstream::ScalarType::Undefined:
|
||||
return failure();
|
||||
default:
|
||||
llvm::report_fatal_error("unhandled type for getTypeForScalarType");
|
||||
}
|
||||
|
@ -123,6 +126,7 @@ Torch::getTorchTypeForScalarType(MLIRContext *context,
|
|||
return Torch::FloatType::get(context);
|
||||
case torch_upstream::ScalarType::Long:
|
||||
return Torch::IntType::get(context);
|
||||
case torch_upstream::ScalarType::Undefined:
|
||||
default:
|
||||
return failure();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue