Replace some deprecated uses of cast (#3343)

Contributing towards #3299
zjgarvey 2024-05-23 11:01:47 -05:00 committed by GitHub
parent 5bb1a65ec9
commit 27169dcda9
4 changed files with 32 additions and 36 deletions
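
For context: MLIR (and LLVM generally) has deprecated the member-function cast forms on Type, Attribute, and Value — x.cast<T>(), x.dyn_cast<T>(), x.dyn_cast_or_null<T>() — in favor of the free functions cast<T>(x), dyn_cast<T>(x), and dyn_cast_or_null<T>(x). This commit mechanically rewrites a batch of the old spellings. Below is a minimal sketch of the before/after pattern, assuming MLIR headers are available and using builtin types in place of the Torch dialect types; the helper names are illustrative, not part of this patch.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h" // brings llvm::cast and friends into the mlir namespace

// Illustrative helpers, not part of this patch.
static int64_t rankOf(mlir::Type t) {
  // Before (deprecated): t.cast<mlir::RankedTensorType>().getRank();
  // After: the free-function form. Like the member form, it asserts on mismatch.
  return mlir::cast<mlir::RankedTensorType>(t).getRank();
}

static bool isF32Tensor(mlir::Type t) {
  // dyn_cast returns a null type on mismatch instead of asserting.
  if (auto tensorTy = mlir::dyn_cast<mlir::RankedTensorType>(t))
    return tensorTy.getElementType().isF32();
  return false;
}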


@@ -51,7 +51,7 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
 }
 
 MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) {
-  auto type = unwrap(t).cast<Torch::OptionalType>();
+  auto type = cast<Torch::OptionalType>(unwrap(t));
   return wrap(type.getContainedType());
 }
@@ -77,12 +77,12 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
 }
 
 size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) {
-  auto type = unwrap(t).cast<Torch::TupleType>();
+  auto type = cast<Torch::TupleType>(unwrap(t));
   return type.getContainedTypes().size();
 }
 
 MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) {
-  auto type = unwrap(t).cast<Torch::TupleType>();
+  auto type = cast<Torch::TupleType>(unwrap(t));
   return wrap(type.getContainedTypes()[pos]);
 }
@@ -108,12 +108,12 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
 }
 
 size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) {
-  auto type = unwrap(t).cast<Torch::UnionType>();
+  auto type = cast<Torch::UnionType>(unwrap(t));
   return type.getContainedTypes().size();
 }
 
 MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) {
-  auto type = unwrap(t).cast<Torch::UnionType>();
+  auto type = cast<Torch::UnionType>(unwrap(t));
   return wrap(type.getContainedTypes()[pos]);
 }
@@ -134,7 +134,7 @@ MlirType torchMlirTorchListTypeGet(MlirType containedType) {
 }
 
 MlirType torchMlirTorchListTypeGetContainedType(MlirType t) {
-  return wrap(unwrap(t).cast<Torch::ListType>().getContainedType());
+  return wrap(cast<Torch::ListType>(unwrap(t)).getContainedType());
 }
 
 MlirTypeID torchMlirTorchListTypeGetTypeID() {
@@ -297,26 +297,26 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
 MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
   auto attrTensorType =
-      unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
+      cast<RankedTensorType>(cast<TypedAttr>(unwrap(attr)).getType());
   return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(),
                                              attrTensorType.getShape(),
                                              attrTensorType.getElementType()));
 }
 
 int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) {
-  return unwrap(t).cast<Torch::NonValueTensorType>().getSizes().size();
+  return cast<Torch::NonValueTensorType>(unwrap(t)).getSizes().size();
 }
 
 bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) {
-  return unwrap(t).cast<Torch::NonValueTensorType>().hasSizes();
+  return cast<Torch::NonValueTensorType>(unwrap(t)).hasSizes();
 }
 
 bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) {
-  return unwrap(t).cast<Torch::NonValueTensorType>().hasDtype();
+  return cast<Torch::NonValueTensorType>(unwrap(t)).hasDtype();
 }
 
 int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
-  auto tensorType = unwrap(t).cast<Torch::NonValueTensorType>();
+  auto tensorType = cast<Torch::NonValueTensorType>(unwrap(t));
   bool hasSizes = tensorType.hasSizes();
   if (!hasSizes)
     return -1;
@@ -329,7 +329,7 @@ int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
 }
 
 MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) {
-  return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype());
+  return wrap(cast<Torch::NonValueTensorType>(unwrap(t)).getDtype());
 }
 
 MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() {
@@ -364,26 +364,26 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
 MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
   auto attrTensorType =
-      unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
+      cast<RankedTensorType>(cast<TypedAttr>(unwrap(attr)).getType());
   return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(),
                                           attrTensorType.getShape(),
                                           attrTensorType.getElementType()));
 }
 
 int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) {
-  return unwrap(t).cast<Torch::ValueTensorType>().getSizes().size();
+  return cast<Torch::ValueTensorType>(unwrap(t)).getSizes().size();
 }
 
 bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) {
-  return unwrap(t).cast<Torch::ValueTensorType>().hasSizes();
+  return cast<Torch::ValueTensorType>(unwrap(t)).hasSizes();
 }
 
 bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) {
-  return unwrap(t).cast<Torch::ValueTensorType>().hasDtype();
+  return cast<Torch::ValueTensorType>(unwrap(t)).hasDtype();
 }
 
 int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
-  auto tensorType = unwrap(t).cast<Torch::ValueTensorType>();
+  auto tensorType = cast<Torch::ValueTensorType>(unwrap(t));
   bool hasSizes = tensorType.hasSizes();
   if (!hasSizes)
     return -1;
@@ -396,7 +396,7 @@ int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
 }
 
 MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) {
-  return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype());
+  return wrap(cast<Torch::ValueTensorType>(unwrap(t)).getDtype());
 }
 
 MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() {
@@ -487,12 +487,12 @@ MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType,
 }
 
 MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) {
-  auto type = unwrap(t).cast<Torch::DictType>();
+  auto type = cast<Torch::DictType>(unwrap(t));
   return wrap(type.getKeyType());
 }
 
 MlirType torchMlirTorchDictTypeGetValueType(MlirType t) {
-  auto type = unwrap(t).cast<Torch::DictType>();
+  auto type = cast<Torch::DictType>(unwrap(t));
   return wrap(type.getValueType());
 }
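
All of the hunks above follow the MLIR C-API idiom: unwrap converts the C handle (MlirType, MlirAttribute) to its C++ counterpart, the free-function cast refines it to the concrete Torch type, and wrap converts results back to C handles. A hedged sketch of an analogous getter — the function name is hypothetical, and it assumes the MLIR C-API support headers:

#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h" // wrap()/unwrap() between C handles and C++ objects
#include "mlir/IR/BuiltinTypes.h"

// Hypothetical example in the same style as the getters above.
MlirType exampleShapedTypeGetElementType(MlirType t) {
  // unwrap: MlirType (C handle) -> mlir::Type; cast<> then refines it.
  auto shaped = mlir::cast<mlir::ShapedType>(unwrap(t));
  // wrap: mlir::Type -> MlirType for the C side.
  return wrap(shaped.getElementType());
}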


@@ -63,7 +63,7 @@ LogicalResult windowFunctionImpl(OpBinder binder,
   // Create an f32 ValueTensorType with the same size as size, the
   // operand
   auto shapeOfOperand =
-      size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
+      dyn_cast<Torch::ValueTensorType>(size.getType()).getOptionalSizes();
   auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
       shapeOfOperand, rewriter.getF32Type());
   Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
@@ -897,8 +897,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
        }
        if (DenseResourceElementsAttr attr =
-               binder.op->getAttr("torch.onnx.value")
-                   .dyn_cast_or_null<DenseResourceElementsAttr>()) {
+               dyn_cast_or_null<DenseResourceElementsAttr>(
+                   binder.op->getAttr("torch.onnx.value"))) {
          // Bytes are stored in little endian order. Big endian support will
          // require swizzling.
          if (!Endian::little) {
@@ -926,8 +926,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
          return success();
        }
-       if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value")
-                                   .dyn_cast_or_null<ElementsAttr>()) {
+       if (ElementsAttr attr = dyn_cast_or_null<ElementsAttr>(
+               binder.op->getAttr("torch.onnx.value"))) {
          rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
              binder.op, resultType, attr);
          return success();
@@ -2283,9 +2283,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
            binder.tensorResultType(resultType))
          return failure();
        Type listElemType =
-           tensors[0]
-               .getType()
-               .cast<Torch::BaseTensorType>()
+           cast<Torch::BaseTensorType>(tensors[0].getType())
                .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
                                      /*optionalDtype=*/nullptr);
        Type listType = Torch::ListType::get(listElemType);
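
The dyn_cast_or_null changes in this file deserve a note: Operation::getAttr returns a null Attribute when the named attribute is absent, and dyn_cast_or_null tolerates that null (yielding null in turn) where a plain dyn_cast would assert. A minimal sketch of the distinction, with a hypothetical helper name mirroring the "torch.onnx.value" hunks above:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

// Hypothetical helper; not part of this patch.
static mlir::ElementsAttr getValueAttr(mlir::Operation *op) {
  // getAttr may return a null Attribute; dyn_cast_or_null handles that
  // and returns null, while mlir::dyn_cast on a null input would assert.
  return mlir::dyn_cast_or_null<mlir::ElementsAttr>(
      op->getAttr("torch.onnx.value"));
}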


@@ -176,7 +176,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
        }
        auto conditionType =
-           conditionTensor.getType().cast<Torch::ValueTensorType>();
+           cast<Torch::ValueTensorType>(conditionTensor.getType());
        if (!conditionType || conditionType.getSizes().size() != 1)
          return rewriter.notifyMatchFailure(
              binder.op, "condition must have one single element per "


@@ -1875,10 +1875,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
              binder.op, "Axes should be the same size of starts and ends");
        }
-       auto stepsTy = steps.getType()
-                          .cast<Torch::ValueTensorType>()
-                          .toBuiltinTensor()
-                          .dyn_cast<RankedTensorType>();
+       auto stepsTy = dyn_cast<RankedTensorType>(
+           cast<Torch::ValueTensorType>(steps.getType()).toBuiltinTensor());
        if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0)))
          return rewriter.notifyMatchFailure(
@@ -2804,7 +2802,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
        Value modeStrValue;
        auto extract = [&rewriter, &binder](Value x, Value v) {
-         auto xTy = x.getType().cast<Torch::ValueTensorType>();
+         auto xTy = cast<Torch::ValueTensorType>(x.getType());
          Type extractTy = rewriter.getType<Torch::FloatType>();
          if (isa<IntegerType>(xTy.getDtype()))
            extractTy = rewriter.getType<Torch::IntType>();
@@ -2818,7 +2816,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
          auto sizes =
              dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
          Torch::BaseTensorType operandType =
-             operand.getType().cast<Torch::BaseTensorType>();
+             cast<Torch::BaseTensorType>(operand.getType());
          SmallVector<int64_t> selectSizes;
          selectSizes.push_back(1);
@@ -2835,7 +2833,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
          Value item = extract(operand, ext);
          itemList.push_back(item);
        }
-       auto xTy = operand.getType().cast<Torch::ValueTensorType>();
+       auto xTy = cast<Torch::ValueTensorType>(operand.getType());
        Value ValueList;
        if (isa<IntegerType>(xTy.getDtype())) {
          ValueList = rewriter.create<Torch::PrimListConstructOp>(
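
A general pattern visible in the larger hunks (e.g. stepsTy above): a chain of member casts a.cast<X>().f().dyn_cast<Y>() becomes nested free-function calls applied innermost-first. A hedged sketch with builtin types standing in for the Torch types and an illustrative function name:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Illustrative only. Before (deprecated):
//   v.getType().cast<mlir::ShapedType>().getElementType()
//       .dyn_cast<mlir::FloatType>();
static mlir::FloatType elementTypeAsFloat(mlir::Value v) {
  // The innermost cast refines the type first; dyn_cast then checks
  // the element type, returning null if it is not a float type.
  return mlir::dyn_cast<mlir::FloatType>(
      mlir::cast<mlir::ShapedType>(v.getType()).getElementType());
}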