[torch] Unpacking sometimes misses shape inference (#3609)

It is possible that the unpacked tensor does not match the same inferred
shapes. This is pretty common when ingesting form the `onnx` frontend.
pull/3622/head
Rob Suderman 2024-08-08 16:17:31 -07:00 committed by GitHub
parent f91f816336
commit fd98476f77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 1 deletions

View File

@ -3290,7 +3290,18 @@ void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
if (op->getNumResults() != listConstruct.getElements().size()) if (op->getNumResults() != listConstruct.getElements().size())
return failure(); return failure();
rewriter.replaceOp(op, listConstruct.getElements()); SmallVector<Value> unpacked;
for (int i = 0, s = op->getNumResults(); i < s; ++i) {
auto element = listConstruct.getElements()[i];
if (element.getType() != op->getResult(i).getType()) {
element = rewriter.create<TensorStaticInfoCastOp>(
op.getLoc(), op->getResult(i).getType(), element);
}
unpacked.push_back(element);
}
rewriter.replaceOp(op, unpacked);
return success(); return success();
}); });
} }

View File

@ -1890,6 +1890,18 @@ func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !t
return %1#0, %1#1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> return %1#0, %1#1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>
} }
// CHECK-LABEL: func.func @prim.ListUnpack$fold_list_cast(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) {
// CHECK: %[[CAST0:.+]] = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,?],f32>
// CHECK: %[[CAST1:.+]] = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,?],f32>
// CHECK: return %[[CAST0]], %[[CAST1]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>
func.func @prim.ListUnpack$fold_list_cast(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) {
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor>
%1:2 = torch.prim.ListUnpack %0 : !torch.list<vtensor> -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>
return %1#0, %1#1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>
}
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64> // CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64> // CHECK: return %[[CST]] : !torch.vtensor<[],si64>