mirror of https://github.com/llvm/torch-mlir
[torch] Update folders for splat operators (#3012)
Splat operators required the output is 1-D. This was not a required restriction and was loosened to 2d.pull/2920/head
parent
4b1e87ce67
commit
e78c99e74e
|
@ -3768,15 +3768,18 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
|
|||
|
||||
Type resultType = getResult().getType();
|
||||
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
|
||||
if (!resultTensorType || !resultTensorType.hasDtype()) {
|
||||
if (!resultTensorType || !resultTensorType.hasDtype() ||
|
||||
!resultTensorType.hasSizes()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int64_t ct = sizes.size();
|
||||
if (resultTensorType.getSizes().size() != 1)
|
||||
return nullptr;
|
||||
if (resultTensorType.getSizes()[0] != ct)
|
||||
return nullptr;
|
||||
for (auto sz : sizes)
|
||||
if (sz == Torch::kUnknownSize || sz < 0)
|
||||
return nullptr;
|
||||
|
||||
for (auto sz : resultTensorType.getSizes())
|
||||
if (sz == Torch::kUnknownSize || sz < 0)
|
||||
return nullptr;
|
||||
|
||||
ShapedType shapedty =
|
||||
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
|
||||
|
@ -3804,15 +3807,18 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
|
|||
|
||||
Type resultType = getResult().getType();
|
||||
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
|
||||
if (!resultTensorType || !resultTensorType.hasDtype()) {
|
||||
if (!resultTensorType || !resultTensorType.hasDtype() ||
|
||||
!resultTensorType.hasSizes()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int64_t ct = sizes.size();
|
||||
if (resultTensorType.getSizes().size() != 1)
|
||||
return nullptr;
|
||||
if (resultTensorType.getSizes()[0] != ct)
|
||||
return nullptr;
|
||||
for (auto sz : sizes)
|
||||
if (sz == Torch::kUnknownSize || sz < 0)
|
||||
return nullptr;
|
||||
|
||||
for (auto sz : resultTensorType.getSizes())
|
||||
if (sz == Torch::kUnknownSize || sz < 0)
|
||||
return nullptr;
|
||||
|
||||
ShapedType shapedty =
|
||||
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
|
||||
|
@ -3842,22 +3848,22 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
|
|||
|
||||
Type resultType = getResult().getType();
|
||||
BaseTensorType resultTensorType = resultType.dyn_cast<BaseTensorType>();
|
||||
if (!resultTensorType || !resultTensorType.hasDtype()) {
|
||||
if (!resultTensorType || !resultTensorType.hasDtype() ||
|
||||
!resultTensorType.hasSizes()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int64_t ct = sizes.size();
|
||||
if (resultTensorType.getSizes().size() != 1)
|
||||
return nullptr;
|
||||
if (resultTensorType.getSizes()[0] != ct)
|
||||
return nullptr;
|
||||
for (auto sz : sizes)
|
||||
if (sz == Torch::kUnknownSize || sz < 0)
|
||||
return nullptr;
|
||||
|
||||
for (auto sz : resultTensorType.getSizes())
|
||||
if (sz == Torch::kUnknownSize || sz < 0)
|
||||
return nullptr;
|
||||
|
||||
ShapedType shapedty =
|
||||
mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType
|
||||
sizes, resultTensorType.getDtype());
|
||||
if (!shapedty) {
|
||||
return nullptr;
|
||||
}
|
||||
mlir::RankedTensorType::get(sizes, resultTensorType.getDtype());
|
||||
|
||||
auto elementType = shapedty.getElementType();
|
||||
if (elementType.isa<IntegerType>()) {
|
||||
int64_t value = 0;
|
||||
|
|
Loading…
Reference in New Issue