[Stablehlo] lowering aten.view to shape.num_elements + stablehlo.compute_reshape_shape (#3125)

`aten.view` supports at most one `-1` in the dim list, and the original calculation of `numel` (a product over the requested dims) is wrong whenever the dim list contains a `-1`. The element count must instead be taken from the input tensor, which is what `shape.num_elements` provides.
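
For illustration, here is a minimal standalone C++ sketch (not the torch-mlir pattern itself; `numElements` and `inferReshapeShape` are hypothetical helpers) of the arithmetic involved: multiplying the requested dims while they still contain the `-1` gives a meaningless element count, whereas taking the element count from the input, as `shape.num_elements` does, lets `stablehlo.compute_reshape_shape` resolve the `-1`.

#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Element count of a static shape; this models what shape.num_elements
// yields for the *input* tensor.
int64_t numElements(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Rough model of what stablehlo.compute_reshape_shape does: resolve at most
// one -1 in the requested dims against the true element count.
std::vector<int64_t> inferReshapeShape(int64_t numel,
                                       std::vector<int64_t> dims) {
  int64_t known = 1;
  int wildcard = -1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    if (dims[i] == -1) {
      assert(wildcard == -1 && "aten.view allows at most one -1");
      wildcard = i;
    } else {
      known *= dims[i];
    }
  }
  if (wildcard != -1)
    dims[wildcard] = numel / known;
  return dims;
}

int main() {
  std::vector<int64_t> input = {2, 15, 224}; // 6720 elements
  std::vector<int64_t> target = {-1, 224};   // aten.view dim list

  // Old lowering: product over the dim list, which still contains the -1.
  std::cout << numElements(target) << "\n"; // -224, not a valid numel
  // New lowering: numel comes from the input shape instead.
  int64_t numel = numElements(input); // 6720
  for (int64_t d : inferReshapeShape(numel, target))
    std::cout << d << " "; // 30 224
  std::cout << "\n";
}

With these example shapes the old product gives -224, while the input actually has 6720 elements, so the inferred result shape is 30x224; this corresponds to the shape.shape_of + shape.num_elements sequence that the updated lowering emits.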
Yuanqiang Liu 2024-04-09 14:54:57 +08:00 committed by GitHub
parent 42a16fa912
commit 8d5e2578b0
2 changed files with 16 additions and 48 deletions


@@ -13,6 +13,7 @@
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
@@ -178,8 +179,7 @@ public:
}
auto loc = op.getLoc();
-auto newRank = dimSizes.size();
-if (newRank == 0 || rankType.getRank() == 0) {
+if (dimSizes.size() == 0 || rankType.getRank() == 0) {
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@@ -193,35 +193,9 @@ public:
return dSize;
});
-const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
-Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
-if (options.dimSizeIndexBits == 32) {
-// The i64 calculation is much slower than i32 on some devices, such as
-// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
-// unlikely to exceed the range of i32(4GiB)
-std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
-// dimSize: cast i64 -> i32
-dSize = rewriter.create<arith::TruncIOp>(loc, intType, dSize);
-return dSize;
-});
-}
+Value numel = rewriter.create<shape::NumElementsOp>(
+loc, rewriter.create<shape::ShapeOfOp>(loc, adaptor.getSelf()));
-Value numel = rewriter.create<arith::ConstantOp>(
-loc, rewriter.getIntegerAttr(intType, 1));
-for (auto d : dimSizes) {
-numel = rewriter.create<arith::MulIOp>(loc, numel, d);
-}
-numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
-numel);
-if (dimSizes.size() == 0) {
-rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
-op,
-OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-op.getType()),
-adaptor.getSelf());
-return success();
-}
Value stablehloShape =
rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>(


@@ -308,15 +308,13 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2
// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]224 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1
// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT224]]
-// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK: %[[T4:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64
-// CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64
-// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index
+// CHECK: %[[T4:.*]] = shape.shape_of %[[T0]] : tensor<?x?x?x?xf32> -> tensor<4xindex>
+// CHECK: %[[T5:.*]] = shape.num_elements %[[T4]] : tensor<4xindex> -> index
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64>
-// CHECK: %[[T7:.*]] = stablehlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
-// CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
-// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
-// CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32>
+// CHECK: %[[T6:.*]] = stablehlo.compute_reshape_shape %[[T5]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
+// CHECK: %[[T7:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T6]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
+// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
+// CHECK: return %[[T8]] : !torch.vtensor<[?,224],f32>
func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
%int-1 = torch.constant.int -1
%int224 = torch.constant.int 224
@@ -339,17 +337,13 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch
// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT120]]
// CHECK: %[[T4:.*]] = torch_c.to_i64 %[[INT4]]
// CHECK: %[[T5:.*]] = torch_c.to_i64 %[[INT64]]
-// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK: %[[T6:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64
-// CHECK: %[[T7:.*]] = arith.muli %[[T6]], %[[T3]] : i64
-// CHECK: %[[T8:.*]] = arith.muli %[[T7]], %[[T4]] : i64
-// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64
-// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
+// CHECK: %[[T6:.*]] = shape.shape_of %[[T0]] : tensor<?x?x?x?x?xf32> -> tensor<5xindex>
+// CHECK: %[[T7:.*]] = shape.num_elements %[[T6]] : tensor<5xindex> -> index
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
-// CHECK: %[[T11:.*]] = stablehlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
-// CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
-// CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
-// CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32>
+// CHECK: %[[T8:.*]] = stablehlo.compute_reshape_shape %[[T7]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T8]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
+// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
+// CHECK: return %[[T10]] : !torch.vtensor<[?,120,4,64],f32>
func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
%int-1 = torch.constant.int -1
%int120 = torch.constant.int 120