mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg][test] Add test for ConvertAtenConvolutionOp (#3679)
This patch add a test for 638ef14
, which use `linalg.broadcast` instead
of `generic` for convolution bias.
Co-authored-by: Rongsheng Gao <gaorongsheng@huawei.com>
pull/3681/head
parent
fd759e4b1f
commit
3180704b14
|
@ -54,3 +54,29 @@ func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtens
|
|||
%11 = torch.aten.dequantize.tensor %10 : !torch.vtensor<[?,?,?,?],!torch.qint32> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
return %11 : !torch.vtensor<[?,?,?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @conv_broadcast(
|
||||
// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,80,3000],f32>,
|
||||
// CHECK-SAME: %[[arg1:.*]]: !torch.vtensor<[1024,80,3],f32>,
|
||||
// CHECK-SAME: %[[arg2:.*]]: !torch.vtensor<[1024],f32>) -> !torch.vtensor<[1,1024,3000],f32> {
|
||||
// CHECK: %[[c0:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: %[[input:.*]] = torch_c.to_builtin_tensor %[[arg0]] : !torch.vtensor<[1,80,3000],f32> -> tensor<1x80x3000xf32>
|
||||
// CHECK-DAG: %[[weight:.*]] = torch_c.to_builtin_tensor %[[arg1]] : !torch.vtensor<[1024,80,3],f32> -> tensor<1024x80x3xf32>
|
||||
// CHECK-DAG: %[[bias:.*]] = torch_c.to_builtin_tensor %[[arg2]] : !torch.vtensor<[1024],f32> -> tensor<1024xf32>
|
||||
// CHECK: %[[padInput:.*]] = tensor.pad %[[input]] low[0, 0, 1] high[0, 0, 1]
|
||||
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1024x3000xf32>
|
||||
// CHECK: %[[broadcastBias:.*]] = linalg.broadcast ins(%[[bias]] : tensor<1024xf32>) outs(%[[EMPTY]] : tensor<1x1024x3000xf32>) dimensions = [0, 2]
|
||||
// CHECK: %[[conv:.*]] = linalg.conv_1d_ncw_fcw {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
|
||||
// CHECK-SAME: ins(%[[padInput:.*]], %[[weight]] : tensor<1x80x3002xf32>, tensor<1024x80x3xf32>)
|
||||
// CHECK-SAME: outs(%[[broadcastBias]] : tensor<1x1024x3000xf32>) -> tensor<1x1024x3000xf32>
|
||||
func.func @conv_broadcast(%arg0: !torch.vtensor<[1,80,3000],f32>, %arg1: !torch.vtensor<[1024,80,3],f32>, %arg2: !torch.vtensor<[1024],f32>) -> !torch.vtensor<[1,1024,3000],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%false = torch.constant.bool false
|
||||
%0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
|
||||
%1 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
|
||||
%2 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[1,80,3000],f32>, !torch.vtensor<[1024,80,3],f32>, !torch.vtensor<[1024],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1024,3000],f32>
|
||||
return %2 : !torch.vtensor<[1,1024,3000],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue