mirror of https://github.com/llvm/torch-mlir
[Onnx] expand support for constant matching (#3607)
The pattern `m_OnnxListOfConstantInts` previously only checked if the attr inside an `onnx.Constant` op is a `DenseResourceElementsAttr`, but didn't handle `ElementsAttr`'s. This patch adds support for `ElementsAttr` and provides an example of it's use via a lit test for `onnx.Unsqueeze`.pull/3611/head
parent
341f415b1e
commit
c8efc201f4
|
@ -90,6 +90,12 @@ struct onnx_list_of_constant_ints_op_binder {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
if (ElementsAttr attr = dyn_cast_or_null<ElementsAttr>(
|
||||
constOp->getAttr("torch.onnx.value"))) {
|
||||
for (auto axis : attr.getValues<llvm::APInt>())
|
||||
bind_values.push_back(axis.getSExtValue());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -505,6 +505,18 @@ func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1:
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_dyn_dims
|
||||
func.func @test_unsqueeze_dyn_dims(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} {
|
||||
// CHECK: %[[x0:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||
// CHECK: %[[int1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[x1:.*]] = torch.aten.unsqueeze %arg0, %[[int1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,1,?],f32>
|
||||
%0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64>
|
||||
%1 = torch.operator "onnx.Unsqueeze"(%arg0, %0) : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,1,?],f32>
|
||||
return %1 : !torch.vtensor<[?,1,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unsqueeze_axis_0
|
||||
func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
|
|
Loading…
Reference in New Issue