[TorchToLinalg][test] Add test for ConvertAtenConvolutionOp (#3679)

This patch add a test for 638ef14, which use `linalg.broadcast` instead
of `generic` for convolution bias.

Co-authored-by: Rongsheng Gao <gaorongsheng@huawei.com>
pull/3681/head
Longsheng Mou 2024-08-30 17:51:50 +08:00 committed by GitHub
parent fd759e4b1f
commit 3180704b14
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 26 additions and 0 deletions

View File

@ -54,3 +54,29 @@ func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtens
%11 = torch.aten.dequantize.tensor %10 : !torch.vtensor<[?,?,?,?],!torch.qint32> -> !torch.vtensor<[?,?,?,?],f32>
return %11 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @conv_broadcast(
// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,80,3000],f32>,
// CHECK-SAME: %[[arg1:.*]]: !torch.vtensor<[1024,80,3],f32>,
// CHECK-SAME: %[[arg2:.*]]: !torch.vtensor<[1024],f32>) -> !torch.vtensor<[1,1024,3000],f32> {
// CHECK: %[[c0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[input:.*]] = torch_c.to_builtin_tensor %[[arg0]] : !torch.vtensor<[1,80,3000],f32> -> tensor<1x80x3000xf32>
// CHECK-DAG: %[[weight:.*]] = torch_c.to_builtin_tensor %[[arg1]] : !torch.vtensor<[1024,80,3],f32> -> tensor<1024x80x3xf32>
// CHECK-DAG: %[[bias:.*]] = torch_c.to_builtin_tensor %[[arg2]] : !torch.vtensor<[1024],f32> -> tensor<1024xf32>
// CHECK: %[[padInput:.*]] = tensor.pad %[[input]] low[0, 0, 1] high[0, 0, 1]
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1024x3000xf32>
// CHECK: %[[broadcastBias:.*]] = linalg.broadcast ins(%[[bias]] : tensor<1024xf32>) outs(%[[EMPTY]] : tensor<1x1024x3000xf32>) dimensions = [0, 2]
// CHECK: %[[conv:.*]] = linalg.conv_1d_ncw_fcw {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
// CHECK-SAME: ins(%[[padInput:.*]], %[[weight]] : tensor<1x80x3002xf32>, tensor<1024x80x3xf32>)
// CHECK-SAME: outs(%[[broadcastBias]] : tensor<1x1024x3000xf32>) -> tensor<1x1024x3000xf32>
func.func @conv_broadcast(%arg0: !torch.vtensor<[1,80,3000],f32>, %arg1: !torch.vtensor<[1024,80,3],f32>, %arg2: !torch.vtensor<[1024],f32>) -> !torch.vtensor<[1,1024,3000],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%false = torch.constant.bool false
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
%2 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[1,80,3000],f32>, !torch.vtensor<[1024,80,3],f32>, !torch.vtensor<[1024],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1024,3000],f32>
return %2 : !torch.vtensor<[1,1024,3000],f32>
}