[torch] Update folders for splat operators (#3012)

Splat operators required the output is 1-D. This was not a required
restriction and was loosened to 2d.
pull/2920/head
Rob Suderman 2024-03-11 13:45:49 -07:00 committed by GitHub
parent 4b1e87ce67
commit e78c99e74e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 29 additions and 23 deletions

View File

@ -3768,15 +3768,18 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
if (!resultTensorType || !resultTensorType.hasDtype()) {
if (!resultTensorType || !resultTensorType.hasDtype() ||
!resultTensorType.hasSizes()) {
return nullptr;
}
int64_t ct = sizes.size();
if (resultTensorType.getSizes().size() != 1)
return nullptr;
if (resultTensorType.getSizes()[0] != ct)
return nullptr;
for (auto sz : sizes)
if (sz == Torch::kUnknownSize || sz < 0)
return nullptr;
for (auto sz : resultTensorType.getSizes())
if (sz == Torch::kUnknownSize || sz < 0)
return nullptr;
ShapedType shapedty =
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
@ -3804,15 +3807,18 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
if (!resultTensorType || !resultTensorType.hasDtype()) {
if (!resultTensorType || !resultTensorType.hasDtype() ||
!resultTensorType.hasSizes()) {
return nullptr;
}
int64_t ct = sizes.size();
if (resultTensorType.getSizes().size() != 1)
return nullptr;
if (resultTensorType.getSizes()[0] != ct)
return nullptr;
for (auto sz : sizes)
if (sz == Torch::kUnknownSize || sz < 0)
return nullptr;
for (auto sz : resultTensorType.getSizes())
if (sz == Torch::kUnknownSize || sz < 0)
return nullptr;
ShapedType shapedty =
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
@ -3842,22 +3848,22 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
Type resultType = getResult().getType();
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
if (!resultTensorType || !resultTensorType.hasDtype()) {
if (!resultTensorType || !resultTensorType.hasDtype() ||
!resultTensorType.hasSizes()) {
return nullptr;
}
int64_t ct = sizes.size();
if (resultTensorType.getSizes().size() != 1)
return nullptr;
if (resultTensorType.getSizes()[0] != ct)
return nullptr;
for (auto sz : sizes)
if (sz == Torch::kUnknownSize || sz < 0)
return nullptr;
for (auto sz : resultTensorType.getSizes())
if (sz == Torch::kUnknownSize || sz < 0)
return nullptr;
ShapedType shapedty =
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
sizes, resultTensorType.getDtype());
if (!shapedty) {
return nullptr;
}
mlir::RankedTensorType::get(sizes, resultTensorType.getDtype());
auto elementType = shapedty.getElementType();
if (elementType.isa<IntegerType>()) {
int64_t value = 0;