mirror of https://github.com/llvm/torch-mlir
parent
5bb1a65ec9
commit
27169dcda9
|
@ -51,7 +51,7 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
|
|||
}
|
||||
|
||||
MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) {
|
||||
auto type = unwrap(t).cast<Torch::OptionalType>();
|
||||
auto type = cast<Torch::OptionalType>(unwrap(t));
|
||||
return wrap(type.getContainedType());
|
||||
}
|
||||
|
||||
|
@ -77,12 +77,12 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
|
|||
}
|
||||
|
||||
size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) {
|
||||
auto type = unwrap(t).cast<Torch::TupleType>();
|
||||
auto type = cast<Torch::TupleType>(unwrap(t));
|
||||
return type.getContainedTypes().size();
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) {
|
||||
auto type = unwrap(t).cast<Torch::TupleType>();
|
||||
auto type = cast<Torch::TupleType>(unwrap(t));
|
||||
return wrap(type.getContainedTypes()[pos]);
|
||||
}
|
||||
|
||||
|
@ -108,12 +108,12 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
|
|||
}
|
||||
|
||||
size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) {
|
||||
auto type = unwrap(t).cast<Torch::UnionType>();
|
||||
auto type = cast<Torch::UnionType>(unwrap(t));
|
||||
return type.getContainedTypes().size();
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) {
|
||||
auto type = unwrap(t).cast<Torch::UnionType>();
|
||||
auto type = cast<Torch::UnionType>(unwrap(t));
|
||||
return wrap(type.getContainedTypes()[pos]);
|
||||
}
|
||||
|
||||
|
@ -134,7 +134,7 @@ MlirType torchMlirTorchListTypeGet(MlirType containedType) {
|
|||
}
|
||||
|
||||
MlirType torchMlirTorchListTypeGetContainedType(MlirType t) {
|
||||
return wrap(unwrap(t).cast<Torch::ListType>().getContainedType());
|
||||
return wrap(cast<Torch::ListType>(unwrap(t)).getContainedType());
|
||||
}
|
||||
|
||||
MlirTypeID torchMlirTorchListTypeGetTypeID() {
|
||||
|
@ -297,26 +297,26 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
|
|||
|
||||
MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
|
||||
auto attrTensorType =
|
||||
unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(cast<TypedAttr>(unwrap(attr)).getType());
|
||||
return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(),
|
||||
attrTensorType.getShape(),
|
||||
attrTensorType.getElementType()));
|
||||
}
|
||||
|
||||
int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) {
|
||||
return unwrap(t).cast<Torch::NonValueTensorType>().getSizes().size();
|
||||
return cast<Torch::NonValueTensorType>(unwrap(t)).getSizes().size();
|
||||
}
|
||||
|
||||
bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) {
|
||||
return unwrap(t).cast<Torch::NonValueTensorType>().hasSizes();
|
||||
return cast<Torch::NonValueTensorType>(unwrap(t)).hasSizes();
|
||||
}
|
||||
|
||||
bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) {
|
||||
return unwrap(t).cast<Torch::NonValueTensorType>().hasDtype();
|
||||
return cast<Torch::NonValueTensorType>(unwrap(t)).hasDtype();
|
||||
}
|
||||
|
||||
int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
|
||||
auto tensorType = unwrap(t).cast<Torch::NonValueTensorType>();
|
||||
auto tensorType = cast<Torch::NonValueTensorType>(unwrap(t));
|
||||
bool hasSizes = tensorType.hasSizes();
|
||||
if (!hasSizes)
|
||||
return -1;
|
||||
|
@ -329,7 +329,7 @@ int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
|
|||
}
|
||||
|
||||
MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) {
|
||||
return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype());
|
||||
return wrap(cast<Torch::NonValueTensorType>(unwrap(t)).getDtype());
|
||||
}
|
||||
|
||||
MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() {
|
||||
|
@ -364,26 +364,26 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
|
|||
|
||||
MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
|
||||
auto attrTensorType =
|
||||
unwrap(attr).cast<TypedAttr>().getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(cast<TypedAttr>(unwrap(attr)).getType());
|
||||
return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(),
|
||||
attrTensorType.getShape(),
|
||||
attrTensorType.getElementType()));
|
||||
}
|
||||
|
||||
int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) {
|
||||
return unwrap(t).cast<Torch::ValueTensorType>().getSizes().size();
|
||||
return cast<Torch::ValueTensorType>(unwrap(t)).getSizes().size();
|
||||
}
|
||||
|
||||
bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) {
|
||||
return unwrap(t).cast<Torch::ValueTensorType>().hasSizes();
|
||||
return cast<Torch::ValueTensorType>(unwrap(t)).hasSizes();
|
||||
}
|
||||
|
||||
bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) {
|
||||
return unwrap(t).cast<Torch::ValueTensorType>().hasDtype();
|
||||
return cast<Torch::ValueTensorType>(unwrap(t)).hasDtype();
|
||||
}
|
||||
|
||||
int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
|
||||
auto tensorType = unwrap(t).cast<Torch::ValueTensorType>();
|
||||
auto tensorType = cast<Torch::ValueTensorType>(unwrap(t));
|
||||
bool hasSizes = tensorType.hasSizes();
|
||||
if (!hasSizes)
|
||||
return -1;
|
||||
|
@ -396,7 +396,7 @@ int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
|
|||
}
|
||||
|
||||
MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) {
|
||||
return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype());
|
||||
return wrap(cast<Torch::ValueTensorType>(unwrap(t)).getDtype());
|
||||
}
|
||||
|
||||
MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() {
|
||||
|
@ -487,12 +487,12 @@ MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType,
|
|||
}
|
||||
|
||||
MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) {
|
||||
auto type = unwrap(t).cast<Torch::DictType>();
|
||||
auto type = cast<Torch::DictType>(unwrap(t));
|
||||
return wrap(type.getKeyType());
|
||||
}
|
||||
|
||||
MlirType torchMlirTorchDictTypeGetValueType(MlirType t) {
|
||||
auto type = unwrap(t).cast<Torch::DictType>();
|
||||
auto type = cast<Torch::DictType>(unwrap(t));
|
||||
return wrap(type.getValueType());
|
||||
}
|
||||
|
||||
|
|
|
@ -63,7 +63,7 @@ LogicalResult windowFunctionImpl(OpBinder binder,
|
|||
// Create an f32 ValueTensorType with thse same size as size, the
|
||||
// operand
|
||||
auto shapeOfOperand =
|
||||
size.getType().dyn_cast<Torch::ValueTensorType>().getOptionalSizes();
|
||||
dyn_cast<Torch::ValueTensorType>(size.getType()).getOptionalSizes();
|
||||
auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
|
||||
shapeOfOperand, rewriter.getF32Type());
|
||||
Value periodicSizeFloat = b.create<Torch::AtenToDtypeOp>(
|
||||
|
@ -897,8 +897,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
}
|
||||
|
||||
if (DenseResourceElementsAttr attr =
|
||||
binder.op->getAttr("torch.onnx.value")
|
||||
.dyn_cast_or_null<DenseResourceElementsAttr>()) {
|
||||
dyn_cast_or_null<DenseResourceElementsAttr>(
|
||||
binder.op->getAttr("torch.onnx.value"))) {
|
||||
// Bytes are stored in little endian order. Big endian support will
|
||||
// require swizzling.
|
||||
if (!Endian::little) {
|
||||
|
@ -926,8 +926,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
return success();
|
||||
}
|
||||
|
||||
if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value")
|
||||
.dyn_cast_or_null<ElementsAttr>()) {
|
||||
if (ElementsAttr attr = dyn_cast_or_null<ElementsAttr>(
|
||||
binder.op->getAttr("torch.onnx.value"))) {
|
||||
rewriter.replaceOpWithNewOp<Torch::ValueTensorLiteralOp>(
|
||||
binder.op, resultType, attr);
|
||||
return success();
|
||||
|
@ -2283,9 +2283,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
Type listElemType =
|
||||
tensors[0]
|
||||
.getType()
|
||||
.cast<Torch::BaseTensorType>()
|
||||
cast<Torch::BaseTensorType>(tensors[0].getType())
|
||||
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
|
||||
/*optionalDtype=*/nullptr);
|
||||
Type listType = Torch::ListType::get(listElemType);
|
||||
|
|
|
@ -176,7 +176,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
}
|
||||
|
||||
auto conditionType =
|
||||
conditionTensor.getType().cast<Torch::ValueTensorType>();
|
||||
cast<Torch::ValueTensorType>(conditionTensor.getType());
|
||||
if (!conditionType || conditionType.getSizes().size() != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "condition must have one single element per "
|
||||
|
|
|
@ -1875,10 +1875,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
binder.op, "Axes should be the same size of starts and ends");
|
||||
}
|
||||
|
||||
auto stepsTy = steps.getType()
|
||||
.cast<Torch::ValueTensorType>()
|
||||
.toBuiltinTensor()
|
||||
.dyn_cast<RankedTensorType>();
|
||||
auto stepsTy = dyn_cast<RankedTensorType>(
|
||||
cast<Torch::ValueTensorType>(steps.getType()).toBuiltinTensor());
|
||||
|
||||
if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -2804,7 +2802,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value modeStrValue;
|
||||
|
||||
auto extract = [&rewriter, &binder](Value x, Value v) {
|
||||
auto xTy = x.getType().cast<Torch::ValueTensorType>();
|
||||
auto xTy = cast<Torch::ValueTensorType>(x.getType());
|
||||
Type extractTy = rewriter.getType<Torch::FloatType>();
|
||||
if (isa<IntegerType>(xTy.getDtype()))
|
||||
extractTy = rewriter.getType<Torch::IntType>();
|
||||
|
@ -2818,7 +2816,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
auto sizes =
|
||||
dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
|
||||
Torch::BaseTensorType operandType =
|
||||
operand.getType().cast<Torch::BaseTensorType>();
|
||||
cast<Torch::BaseTensorType>(operand.getType());
|
||||
|
||||
SmallVector<int64_t> selectSizes;
|
||||
selectSizes.push_back(1);
|
||||
|
@ -2835,7 +2833,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value item = extract(operand, ext);
|
||||
itemList.push_back(item);
|
||||
}
|
||||
auto xTy = operand.getType().cast<Torch::ValueTensorType>();
|
||||
auto xTy = cast<Torch::ValueTensorType>(operand.getType());
|
||||
Value ValueList;
|
||||
if (isa<IntegerType>(xTy.getDtype())) {
|
||||
ValueList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
|
|
Loading…
Reference in New Issue