From e78c99e74e115b4733f06f2ed186f74514982f74 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 11 Mar 2024 13:45:49 -0700 Subject: [PATCH] [torch] Update folders for splat operators (#3012) Splat operators required the output is 1-D. This was not a required restriction and was loosened to 2d. --- lib/Dialect/Torch/IR/TorchOps.cpp | 52 +++++++++++++++++-------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index db2988f25..9ecf0e3e2 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3768,15 +3768,18 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) { Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); - if (!resultTensorType || !resultTensorType.hasDtype()) { + if (!resultTensorType || !resultTensorType.hasDtype() || + !resultTensorType.hasSizes()) { return nullptr; } - int64_t ct = sizes.size(); - if (resultTensorType.getSizes().size() != 1) - return nullptr; - if (resultTensorType.getSizes()[0] != ct) - return nullptr; + for (auto sz : sizes) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + for (auto sz : resultTensorType.getSizes()) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; ShapedType shapedty = mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType @@ -3804,15 +3807,18 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) { Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); - if (!resultTensorType || !resultTensorType.hasDtype()) { + if (!resultTensorType || !resultTensorType.hasDtype() || + !resultTensorType.hasSizes()) { return nullptr; } - int64_t ct = sizes.size(); - if (resultTensorType.getSizes().size() != 1) - return nullptr; - if (resultTensorType.getSizes()[0] != ct) - return nullptr; + for (auto sz : sizes) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + for (auto sz : resultTensorType.getSizes()) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; ShapedType shapedty = mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType @@ -3842,22 +3848,22 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) { Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); - if (!resultTensorType || !resultTensorType.hasDtype()) { + if (!resultTensorType || !resultTensorType.hasDtype() || + !resultTensorType.hasSizes()) { return nullptr; } - int64_t ct = sizes.size(); - if (resultTensorType.getSizes().size() != 1) - return nullptr; - if (resultTensorType.getSizes()[0] != ct) - return nullptr; + for (auto sz : sizes) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + for (auto sz : resultTensorType.getSizes()) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; ShapedType shapedty = - mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType - sizes, resultTensorType.getDtype()); - if (!shapedty) { - return nullptr; - } + mlir::RankedTensorType::get(sizes, resultTensorType.getDtype()); + auto elementType = shapedty.getElementType(); if (elementType.isa()) { int64_t value = 0;