diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index f4a9ca032..edc85c7e7 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -51,7 +51,7 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) { } MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return wrap(type.getContainedType()); } @@ -77,12 +77,12 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context, } size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return type.getContainedTypes().size(); } MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return wrap(type.getContainedTypes()[pos]); } @@ -108,12 +108,12 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context, } size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return type.getContainedTypes().size(); } MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return wrap(type.getContainedTypes()[pos]); } @@ -134,7 +134,7 @@ MlirType torchMlirTorchListTypeGet(MlirType containedType) { } MlirType torchMlirTorchListTypeGetContainedType(MlirType t) { - return wrap(unwrap(t).cast().getContainedType()); + return wrap(cast(unwrap(t)).getContainedType()); } MlirTypeID torchMlirTorchListTypeGetTypeID() { @@ -297,26 +297,26 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation( MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) { auto attrTensorType = - unwrap(attr).cast().getType().cast(); + cast(cast(unwrap(attr)).getType()); return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(), attrTensorType.getShape(), attrTensorType.getElementType())); } int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) { - return unwrap(t).cast().getSizes().size(); + return cast(unwrap(t)).getSizes().size(); } bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) { - return unwrap(t).cast().hasSizes(); + return cast(unwrap(t)).hasSizes(); } bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) { - return unwrap(t).cast().hasDtype(); + return cast(unwrap(t)).hasDtype(); } int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { - auto tensorType = unwrap(t).cast(); + auto tensorType = cast(unwrap(t)); bool hasSizes = tensorType.hasSizes(); if (!hasSizes) return -1; @@ -329,7 +329,7 @@ int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { } MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) { - return wrap(unwrap(t).cast().getDtype()); + return wrap(cast(unwrap(t)).getDtype()); } MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() { @@ -364,26 +364,26 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation( MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) { auto attrTensorType = - unwrap(attr).cast().getType().cast(); + cast(cast(unwrap(attr)).getType()); return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(), attrTensorType.getShape(), attrTensorType.getElementType())); } int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) { - return unwrap(t).cast().getSizes().size(); + return cast(unwrap(t)).getSizes().size(); } bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) { - return unwrap(t).cast().hasSizes(); + return cast(unwrap(t)).hasSizes(); } bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) { - return unwrap(t).cast().hasDtype(); + return cast(unwrap(t)).hasDtype(); } int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { - auto tensorType = unwrap(t).cast(); + auto tensorType = cast(unwrap(t)); bool hasSizes = tensorType.hasSizes(); if (!hasSizes) return -1; @@ -396,7 +396,7 @@ int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { } MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) { - return wrap(unwrap(t).cast().getDtype()); + return wrap(cast(unwrap(t)).getDtype()); } MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() { @@ -487,12 +487,12 @@ MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType, } MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return wrap(type.getKeyType()); } MlirType torchMlirTorchDictTypeGetValueType(MlirType t) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return wrap(type.getValueType()); } diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 889a5fe88..2d074ec59 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -63,7 +63,7 @@ LogicalResult windowFunctionImpl(OpBinder binder, // Create an f32 ValueTensorType with thse same size as size, the // operand auto shapeOfOperand = - size.getType().dyn_cast().getOptionalSizes(); + dyn_cast(size.getType()).getOptionalSizes(); auto f32ResultType = rewriter.getType( shapeOfOperand, rewriter.getF32Type()); Value periodicSizeFloat = b.create( @@ -897,8 +897,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } if (DenseResourceElementsAttr attr = - binder.op->getAttr("torch.onnx.value") - .dyn_cast_or_null()) { + dyn_cast_or_null( + binder.op->getAttr("torch.onnx.value"))) { // Bytes are stored in little endian order. Big endian support will // require swizzling. if (!Endian::little) { @@ -926,8 +926,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); } - if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value") - .dyn_cast_or_null()) { + if (ElementsAttr attr = dyn_cast_or_null( + binder.op->getAttr("torch.onnx.value"))) { rewriter.replaceOpWithNewOp( binder.op, resultType, attr); return success(); @@ -2283,9 +2283,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorResultType(resultType)) return failure(); Type listElemType = - tensors[0] - .getType() - .cast() + cast(tensors[0].getType()) .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index cf14fc026..cfa170c2e 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -176,7 +176,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } auto conditionType = - conditionTensor.getType().cast(); + cast(conditionTensor.getType()); if (!conditionType || conditionType.getSizes().size() != 1) return rewriter.notifyMatchFailure( binder.op, "condition must have one single element per " diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 576534d18..b1ef07a8b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1875,10 +1875,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "Axes should be the same size of starts and ends"); } - auto stepsTy = steps.getType() - .cast() - .toBuiltinTensor() - .dyn_cast(); + auto stepsTy = dyn_cast( + cast(steps.getType()).toBuiltinTensor()); if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0))) return rewriter.notifyMatchFailure( @@ -2804,7 +2802,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value modeStrValue; auto extract = [&rewriter, &binder](Value x, Value v) { - auto xTy = x.getType().cast(); + auto xTy = cast(x.getType()); Type extractTy = rewriter.getType(); if (isa(xTy.getDtype())) extractTy = rewriter.getType(); @@ -2818,7 +2816,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto sizes = dyn_cast(operand.getType()).getSizes(); Torch::BaseTensorType operandType = - operand.getType().cast(); + cast(operand.getType()); SmallVector selectSizes; selectSizes.push_back(1); @@ -2835,7 +2833,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value item = extract(operand, ext); itemList.push_back(item); } - auto xTy = operand.getType().cast(); + auto xTy = cast(operand.getType()); Value ValueList; if (isa(xTy.getDtype())) { ValueList = rewriter.create(