mirror of https://github.com/llvm/torch-mlir
Fix onnx.Gather lowering with dynamic shapes (#3675)
Supports the result with dynamic shape and scalar indices like ``` func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[], si64>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32> } ``` `Torch::AtenSqueezeOp` is referring to the result shape, so it will failed on lowering if the result shape is dynamic.pull/3679/head
parent
98e08023bb
commit
fd759e4b1f
|
@ -1941,7 +1941,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
indicesCt = Torch::kUnknownSize;
|
||||
break;
|
||||
}
|
||||
|
||||
indicesCt *= sz;
|
||||
}
|
||||
|
||||
|
@ -1976,8 +1975,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenSqueezeOp>(binder.op, resultType,
|
||||
gather);
|
||||
// indicesRank = 0 will select 1 from the axis dim and squeeze it
|
||||
// Use AtenSqueezeDimOp for the case of result with dynamic shape
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenSqueezeDimOp>(
|
||||
binder.op, resultType, gather, index);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
|
|
|
@ -78,7 +78,7 @@ func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.
|
|||
// CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1
|
||||
// CHECK: %[[FLAT:.+]] = torch.aten.unsqueeze %[[SEL]], %[[ZERO]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]]
|
||||
// CHECK: %[[RES:.+]] = torch.aten.squeeze %[[ISEL]] : !torch.vtensor<[1,4,5],f32> -> !torch.vtensor<[4,5],f32>
|
||||
// CHECK: %[[RES:.+]] = torch.aten.squeeze.dim %[[ISEL]], %[[AXIS]] : !torch.vtensor<[1,4,5],f32>, !torch.int -> !torch.vtensor<[4,5],f32>
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32>
|
||||
return %0 : !torch.vtensor<[4,5],f32>
|
||||
|
|
Loading…
Reference in New Issue