From fd759e4b1f8c1f9d4d031d570b8048ecf8356790 Mon Sep 17 00:00:00 2001 From: jinchen <49575973+jinchen62@users.noreply.github.com> Date: Thu, 29 Aug 2024 17:02:16 -0700 Subject: [PATCH] Fix onnx.Gather lowering with dynamic shapes (#3675) Supports the result with dynamic shape and scalar indices like ``` func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[], si64>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32> } ``` `Torch::AtenSqueezeOp` is referring to the result shape, so it will failed on lowering if the result shape is dynamic. --- lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp | 7 ++++--- test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index ef50c3bca..168040d9b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1941,7 +1941,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( indicesCt = Torch::kUnknownSize; break; } - indicesCt *= sz; } @@ -1976,8 +1975,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); } - rewriter.replaceOpWithNewOp(binder.op, resultType, - gather); + // indicesRank = 0 will select 1 from the axis dim and squeeze it + // Use AtenSqueezeDimOp for the case of result with dynamic shape + rewriter.replaceOpWithNewOp( + binder.op, resultType, gather, index); return success(); }); patterns.onOp( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 2e7b59088..21be2a65f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -78,7 +78,7 @@ func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1 // CHECK: %[[FLAT:.+]] = torch.aten.unsqueeze %[[SEL]], %[[ZERO]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> // CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]] - // CHECK: %[[RES:.+]] = torch.aten.squeeze %[[ISEL]] : !torch.vtensor<[1,4,5],f32> -> !torch.vtensor<[4,5],f32> + // CHECK: %[[RES:.+]] = torch.aten.squeeze.dim %[[ISEL]], %[[AXIS]] : !torch.vtensor<[1,4,5],f32>, !torch.int -> !torch.vtensor<[4,5],f32> // CHECK: return %[[RES]] %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32> return %0 : !torch.vtensor<[4,5],f32>