[Onnx] expand support for constant matching (#3607)

The pattern `m_OnnxListOfConstantInts` previously only checked if the
attr inside an `onnx.Constant` op is a `DenseResourceElementsAttr`, but
didn't handle `ElementsAttr`'s. This patch adds support for
`ElementsAttr` and provides an example of it's use via a lit test for
`onnx.Unsqueeze`.
pull/3611/head
zjgarvey 2024-08-07 17:35:34 -07:00 committed by GitHub
parent 341f415b1e
commit c8efc201f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 0 deletions

View File

@ -90,6 +90,12 @@ struct onnx_list_of_constant_ints_op_binder {
} }
return true; return true;
} }
if (ElementsAttr attr = dyn_cast_or_null<ElementsAttr>(
constOp->getAttr("torch.onnx.value"))) {
for (auto axis : attr.getValues<llvm::APInt>())
bind_values.push_back(axis.getSExtValue());
return true;
}
return false; return false;
} }
}; };

View File

@ -505,6 +505,18 @@ func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1:
// ----- // -----
// CHECK-LABEL: func.func @test_unsqueeze_dyn_dims
func.func @test_unsqueeze_dyn_dims(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} {
// CHECK: %[[x0:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
// CHECK: %[[int1:.*]] = torch.constant.int 1
// CHECK: %[[x1:.*]] = torch.aten.unsqueeze %arg0, %[[int1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,1,?],f32>
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
%1 = torch.operator "onnx.Unsqueeze"(%arg0, %0) : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,1,?],f32>
return %1 : !torch.vtensor<[?,1,?],f32>
}
// -----
// CHECK-LABEL: func.func @test_unsqueeze_axis_0 // CHECK-LABEL: func.func @test_unsqueeze_axis_0
func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0:.*]] = torch.constant.int 0