diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index db2988f25..9ecf0e3e2 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3768,15 +3768,18 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
   Type resultType = getResult().getType();
   BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
-  if (!resultTensorType || !resultTensorType.hasDtype()) {
+  if (!resultTensorType || !resultTensorType.hasDtype() ||
+      !resultTensorType.hasSizes()) {
     return nullptr;
   }
-  int64_t ct = sizes.size();
-  if (resultTensorType.getSizes().size() != 1)
-    return nullptr;
-  if (resultTensorType.getSizes()[0] != ct)
-    return nullptr;
+  for (auto sz : sizes)
+    if (sz == Torch::kUnknownSize || sz < 0)
+      return nullptr;
+
+  for (auto sz : resultTensorType.getSizes())
+    if (sz == Torch::kUnknownSize || sz < 0)
+      return nullptr;
 
   ShapedType shapedty =
       mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
@@ -3804,15 +3807,18 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
   Type resultType = getResult().getType();
   BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
-  if (!resultTensorType || !resultTensorType.hasDtype()) {
+  if (!resultTensorType || !resultTensorType.hasDtype() ||
+      !resultTensorType.hasSizes()) {
     return nullptr;
   }
-  int64_t ct = sizes.size();
-  if (resultTensorType.getSizes().size() != 1)
-    return nullptr;
-  if (resultTensorType.getSizes()[0] != ct)
-    return nullptr;
+  for (auto sz : sizes)
+    if (sz == Torch::kUnknownSize || sz < 0)
+      return nullptr;
+
+  for (auto sz : resultTensorType.getSizes())
+    if (sz == Torch::kUnknownSize || sz < 0)
+      return nullptr;
 
   ShapedType shapedty =
       mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
@@ -3842,22 +3848,22 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
   Type resultType = getResult().getType();
   BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
-  if (!resultTensorType || !resultTensorType.hasDtype()) {
+  if (!resultTensorType || !resultTensorType.hasDtype() ||
+      !resultTensorType.hasSizes()) {
     return nullptr;
   }
-  int64_t ct = sizes.size();
-  if (resultTensorType.getSizes().size() != 1)
-    return nullptr;
-  if (resultTensorType.getSizes()[0] != ct)
-    return nullptr;
+  for (auto sz : sizes)
+    if (sz == Torch::kUnknownSize || sz < 0)
+      return nullptr;
+
+  for (auto sz : resultTensorType.getSizes())
+    if (sz == Torch::kUnknownSize || sz < 0)
+      return nullptr;
 
   ShapedType shapedty =
-      mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
-          sizes, resultTensorType.getDtype());
-  if (!shapedty) {
-    return nullptr;
-  }
+      mlir::RankedTensorType::get(sizes, resultTensorType.getDtype());
+
   auto elementType = shapedty.getElementType();
   if (elementType.isa<IntegerType>()) {
     int64_t value = 0;
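
The old guard only let these folders fire when the result was rank-1 and its single dimension happened to equal the length of the size list; the new guard instead requires every size-list element and every result dimension to be statically known and non-negative, then bails out otherwise. A minimal sketch of the kind of IR the broadened AtenOnesOp folder should now handle (assuming the usual torch-dialect assembly syntax; the concrete sizes and the f32 result dtype here are illustrative, not taken from the patch):

    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %none = torch.constant.none
    %size = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    // The 2x3 result is fully static, so the fold can materialize a constant
    // even though the result is not rank-1.
    %0 = torch.aten.ones %size, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    // Expected to fold to roughly:
    //   %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>

The same static-shape check is applied in AtenZerosOp::fold and AtenFullOp::fold, so all three folders now reject any dynamic (kUnknownSize) dimension rather than relying on the old rank-1 comparison.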