mirror of https://github.com/llvm/torch-mlir
[torch-mlir][Tosa] fix during torch.max.dim lower to tosa the reshape's new shape attr mismatch reshape's result type (#1378)
parent
04f3a4ffce
commit
797feaf129
|
@ -2866,8 +2866,9 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
|
|||
adaptor.self(), dimAttr);
|
||||
|
||||
if (argMax.getType() != indicesType) {
|
||||
argMax = rewriter.create<tosa::ReshapeOp>(op->getLoc(), indicesType, argMax,
|
||||
prunedShapeAttr);
|
||||
argMax = rewriter.create<tosa::ReshapeOp>(
|
||||
op->getLoc(), indicesType, argMax,
|
||||
rewriter.getI64ArrayAttr(reducedShape));
|
||||
}
|
||||
|
||||
if (!keepDim) {
|
||||
|
|
|
@ -794,3 +794,26 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> )
|
|||
%0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %true, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32>
|
||||
return %0 : !torch.vtensor<[1,512,1,1],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @torch.aten.max.dim$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>)
|
||||
// CHECK: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
|
||||
// CHECK: %[[VAL_TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[VAL_I2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.reduce_max"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.argmax"(%[[VAL_1]]) {axis = 2 : i64} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) {new_shape = [3, 2, 1]} : (tensor<3x2xi64>) -> tensor<3x2x1xi64>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
// CEHCK: return %[[VAL_6]] : tensor<3x2x1xf32>
|
||||
func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
|
||||
%0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
|
||||
%true = torch.constant.bool true
|
||||
%int2 = torch.constant.int 2
|
||||
%values, %indices = torch.aten.max.dim %0, %int2, %true : !torch.vtensor<[3,2,3],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],f32>, !torch.vtensor<[3,2,1],si64>
|
||||
%1 = torch_c.to_builtin_tensor %values : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
|
||||
return %1 : tensor<3x2x1xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue