mirror of https://github.com/llvm/torch-mlir
[torch] Unpacking sometimes misses shape inference (#3609)
It is possible that the unpacked tensor does not match the same inferred shapes. This is pretty common when ingesting form the `onnx` frontend.pull/3622/head
parent
f91f816336
commit
fd98476f77
|
@ -3290,7 +3290,18 @@ void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
if (op->getNumResults() != listConstruct.getElements().size())
|
if (op->getNumResults() != listConstruct.getElements().size())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
rewriter.replaceOp(op, listConstruct.getElements());
|
SmallVector<Value> unpacked;
|
||||||
|
for (int i = 0, s = op->getNumResults(); i < s; ++i) {
|
||||||
|
auto element = listConstruct.getElements()[i];
|
||||||
|
if (element.getType() != op->getResult(i).getType()) {
|
||||||
|
element = rewriter.create<TensorStaticInfoCastOp>(
|
||||||
|
op.getLoc(), op->getResult(i).getType(), element);
|
||||||
|
}
|
||||||
|
|
||||||
|
unpacked.push_back(element);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, unpacked);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -1890,6 +1890,18 @@ func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !t
|
||||||
return %1#0, %1#1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
|
return %1#0, %1#1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @prim.ListUnpack$fold_list_cast(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>,
|
||||||
|
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) {
|
||||||
|
// CHECK: %[[CAST0:.+]] = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: %[[CAST1:.+]] = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[CAST0]], %[[CAST1]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>
|
||||||
|
func.func @prim.ListUnpack$fold_list_cast(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) {
|
||||||
|
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor>
|
||||||
|
%1:2 = torch.prim.ListUnpack %0 : !torch.list<vtensor> -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>
|
||||||
|
return %1#0, %1#1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
|
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
|
||||||
// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
|
// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
|
||||||
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
|
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
|
||||||
|
|
Loading…
Reference in New Issue