mirror of https://github.com/llvm/torch-mlir
[Stablehlo] lowering aten.view to shape.num_elements + stablehlo.compute_reshape_shape (#3125)

`aten.view` supports at most one `-1` in its dim list, and the original calculation of `numel` was wrong when the dim list contained a `-1`.
parent 42a16fa912
commit 8d5e2578b0
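Why the old `numel` was wrong: the previous lowering multiplied the entries of the requested dim list together, so a `-1` placeholder poisons the product. The fix reads the element count off the input tensor's runtime shape instead. A minimal before/after sketch in MLIR for a `view(x, [-1, 224])`, with illustrative value names (the real IR is in the updated tests below):

    // Before: numel is the product of the requested dims, -1 included,
    // so it evaluates to 1 * -1 * 224 = -224, not the true element count.
    %c1 = arith.constant 1 : i64
    %p0 = arith.muli %c1, %dim0 : i64    // %dim0 holds -1
    %p1 = arith.muli %p0, %dim1 : i64    // %dim1 holds 224; %p1 = -224

    // After: numel comes from the input tensor itself, so the -1 entry
    // can be resolved by stablehlo.compute_reshape_shape downstream.
    %shape = shape.shape_of %self : tensor<?x?x?x?xf32> -> tensor<4xindex>
    %numel = shape.num_elements %shape : tensor<4xindex> -> index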
@@ -13,6 +13,7 @@
 #include "PopulatePatterns.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
@@ -178,8 +179,7 @@ public:
     }

     auto loc = op.getLoc();
-    auto newRank = dimSizes.size();
-    if (newRank == 0 || rankType.getRank() == 0) {
+    if (dimSizes.size() == 0 || rankType.getRank() == 0) {
       rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
           op,
           OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@@ -193,35 +193,9 @@ public:
       return dSize;
     });

-    const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
-    Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
-    if (options.dimSizeIndexBits == 32) {
-      // The i64 calculation is much slower than i32 on some devices, such as
-      // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are
-      // unlikely to exceed the range of i32(4GiB)
-      std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
-        // dimSize: cast i64 -> i32
-        dSize = rewriter.create<arith::TruncIOp>(loc, intType, dSize);
-        return dSize;
-      });
-    }
-
-    Value numel = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(intType, 1));
-    for (auto d : dimSizes) {
-      numel = rewriter.create<arith::MulIOp>(loc, numel, d);
-    }
-    numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
-                                                numel);
-    if (dimSizes.size() == 0) {
-      rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
-          op,
-          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-              op.getType()),
-          adaptor.getSelf());
-      return success();
-    }
+    Value numel = rewriter.create<shape::NumElementsOp>(
+        loc, rewriter.create<shape::ShapeOfOp>(loc, adaptor.getSelf()));

     Value stablehloShape =
         rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
     Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>(
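A note on the op this hands off to (my gloss of the StableHLO dialect's behavior, not part of this diff): `stablehlo.compute_reshape_shape` resolves the single dynamic (`-1`) entry by dividing the total element count by the product of the known dims, which is why it must receive the true `numel`. Sketch with hypothetical values:

    // If %numel is 224 * N and %dims holds [-1, 224], the result is [N, 224].
    %result = stablehlo.compute_reshape_shape %numel, %dims
        : (index, tensor<2xi64>) -> tensor<2xi64>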
@@ -308,15 +308,13 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2
 // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]224 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1
 // CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT224]]
-// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK: %[[T4:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64
-// CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64
-// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index
+// CHECK: %[[T4:.*]] = shape.shape_of %[[T0]] : tensor<?x?x?x?xf32> -> tensor<4xindex>
+// CHECK: %[[T5:.*]] = shape.num_elements %[[T4]] : tensor<4xindex> -> index
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64>
-// CHECK: %[[T7:.*]] = stablehlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
-// CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
-// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
-// CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32>
+// CHECK: %[[T6:.*]] = stablehlo.compute_reshape_shape %[[T5]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64>
+// CHECK: %[[T7:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T6]] : (tensor<?x?x?x?xf32>, tensor<2xi64>) -> tensor<?x224xf32>
+// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32>
+// CHECK: return %[[T8]] : !torch.vtensor<[?,224],f32>
 func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
   %int-1 = torch.constant.int -1
   %int224 = torch.constant.int 224
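For reference, a sketch of the torch-level IR this first test lowers (my reconstruction from the CHECK lines above; the hunk truncates the function body):

    %0 = torch.prim.ListConstruct %int-1, %int224 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,224],f32>
    return %1 : !torch.vtensor<[?,224],f32>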
@@ -339,17 +337,13 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> {
 // CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT120]]
 // CHECK: %[[T4:.*]] = torch_c.to_i64 %[[INT4]]
 // CHECK: %[[T5:.*]] = torch_c.to_i64 %[[INT64]]
-// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
-// CHECK: %[[T6:.*]] = arith.muli %[[C1_I64]], %[[T2]] : i64
-// CHECK: %[[T7:.*]] = arith.muli %[[T6]], %[[T3]] : i64
-// CHECK: %[[T8:.*]] = arith.muli %[[T7]], %[[T4]] : i64
-// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64
-// CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index
+// CHECK: %[[T6:.*]] = shape.shape_of %[[T0]] : tensor<?x?x?x?x?xf32> -> tensor<5xindex>
+// CHECK: %[[T7:.*]] = shape.num_elements %[[T6]] : tensor<5xindex> -> index
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64>
-// CHECK: %[[T11:.*]] = stablehlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
-// CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
-// CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
-// CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32>
+// CHECK: %[[T8:.*]] = stablehlo.compute_reshape_shape %[[T7]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64>
+// CHECK: %[[T9:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T8]] : (tensor<?x?x?x?x?xf32>, tensor<4xi64>) -> tensor<?x120x4x64xf32>
+// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32>
+// CHECK: return %[[T10]] : !torch.vtensor<[?,120,4,64],f32>
 func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> {
   %int-1 = torch.constant.int -1
   %int120 = torch.constant.int 120