[torch-mlir][Tosa] fix during torch.max.dim lower to tosa the reshape's new shape attr mismatch reshape's result type (#1378)

pull/1380/head snapshot-20220917.599
long.chen 2022-09-17 12:29:56 +08:00 committed by GitHub
parent 04f3a4ffce
commit 797feaf129
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 2 deletions

View File

@ -2866,8 +2866,9 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
adaptor.self(), dimAttr);
if (argMax.getType() != indicesType) {
argMax = rewriter.create<tosa::ReshapeOp>(op->getLoc(), indicesType, argMax,
prunedShapeAttr);
argMax = rewriter.create<tosa::ReshapeOp>(
op->getLoc(), indicesType, argMax,
rewriter.getI64ArrayAttr(reducedShape));
}
if (!keepDim) {

View File

@ -794,3 +794,26 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
%0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %true, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
return %0 : !torch.vtensor<[1,512,1,1],f32>
}
// -----
// CHECK-LABEL: @torch.aten.max.dim$basic(
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>)
// CHECK: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
// CHECK: %[[VAL_TRUE:.*]] = torch.constant.bool true
// CHECK: %[[VAL_I2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_2:.*]] = "tosa.reduce_max"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.argmax"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [3, 2, 1]} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
// CEHCK: return %[[VAL_6]] : tensor<3x2x1xf32>
func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
%true = torch.constant.bool true
%int2 = torch.constant.int 2
%values, %indices = torch.aten.max.dim %0, %int2, %true : !torch.vtensor<[3,2,3],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],f32>, !torch.vtensor<[3,2,1],si64>
%1 = torch_c.to_builtin_tensor %values : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
return %1 : tensor<3x2x1xf32>
}