TorchToLinalg: Try folding shape computations to keep static shapes when possible (#3475)

Before this PR, a statically shaped aten.convolution would generate
dynamically shaped linalg IR, and even `-canonicalize` would not be able
to fold it back into static shapes. This PR ensures that shape
calculations are folded on construction, so that the lowering directly
generates statically shaped linalg IR.

We achieve that by ensuring that `arith` ops involved in computing
shapes are created via `createOrFold`, so that later uses of
`getAsOpFoldResult` see constants instead of those ops.
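A minimal sketch of the difference (not the exact torch-mlir code; it assumes an `OpBuilder b`, a `Location loc`, and the `mlir`/`mlir::arith` namespaces in scope, as in the snippets below):

```
// With create<>, the addition is materialized as an arith.addi op even when
// both operands are constants; getAsOpFoldResult then only sees a Value and
// downstream type inference must treat the extent as dynamic.
Value two = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));
Value three = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(3));
Value sum = b.create<arith::AddIOp>(loc, two, three);
OpFoldResult r1 = getAsOpFoldResult(sum);      // wraps a Value (dynamic)

// With createOrFold<>, the op's folder runs at construction time, so two
// constant operands produce a constant Value, and getAsOpFoldResult returns
// an Attribute, which type inference treats as a static extent.
Value folded = b.createOrFold<arith::AddIOp>(loc, two, three); // constant 5
OpFoldResult r2 = getAsOpFoldResult(folded);   // wraps an Attribute (static)
```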

For example
```
module {
  func.func @forward(%arg0: !torch.vtensor<[32,336,112,112],f32>,
                        %arg1: !torch.vtensor<[336,168,3,3],f32>, 
                        %arg2: !torch.vtensor<[336],f32>) 
                        -> !torch.vtensor<[32,336,56,56],f32> {
    %false = torch.constant.bool false
    %int2 = torch.constant.int 2
    %int1 = torch.constant.int 1
    %0 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.prim.ListConstruct  : () -> !torch.list<int>
    %3 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %0, %0, %false, %2, %int2 
    : !torch.vtensor<[32,336,112,112],f32>, !torch.vtensor<[336,168,3,3],f32>, !torch.vtensor<[336],f32>, !torch.list<int>,
      !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int
   -> !torch.vtensor<[32,336,56,56],f32>
    return %3 : !torch.vtensor<[32,336,56,56],f32>
  }
}
```
would result in
```
[...]
  %padded = tensor.pad %2 low[%14, %15, %16, %17] high[%14, %15, %16, %17] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
    } : tensor<32x336x112x112xf32> to tensor<?x?x?x?xf32>
[...]
  %45 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
    ins(%expanded, %expanded_37 : tensor<?x2x?x?x?xf32>, tensor<2x168x168x3x3xf32>)
    outs(%expanded_44 : tensor<32x2x168x?x?xf32>) -> tensor<32x2x168x?x?xf32>
[...]
```
whereas with this PR all shapes in the generated IR are static.
Matthias Gehre · 2024-06-27 08:43 +02:00 · commit 6678e1a256 (parent 6eebe61bfe)
11 changed files with 85 additions and 121 deletions


```
@@ -860,7 +860,7 @@ public:
     Type intType = IntegerType::get(context, 64);
     auto castIndexToInt = [&](Value v) {
-      return rewriter.create<arith::IndexCastOp>(loc, intType, v);
+      return rewriter.createOrFold<arith::IndexCastOp>(loc, intType, v);
     };
     SmallVector<Value> paddingIntValues;
```


```
@@ -82,16 +82,10 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
        pad < paddingIncludingUnchanged.end(); pad++)
     *pad = castIntToIndex(b, loc, *pad);
   Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
-  // TODO: audit possibility of sparsity on this tensor
-  Type inputType =
-      RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
-                                SmallVector<int64_t>(inRank, kUnknownSize))),
-                            elementType);
   SmallVector<OpFoldResult> paddingValues =
       getAsOpFoldResult(paddingIncludingUnchanged);
-  return b.create<tensor::PadOp>(loc, inputType, input, /*low=*/paddingValues,
+  return b.create<tensor::PadOp>(loc, Type{}, input, /*low=*/paddingValues,
                                  /*high=*/paddingValues, pad);
 }
```
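Passing `Type{}` instead of a precomputed all-dynamic type lets the `tensor.pad` builder infer the result type from the padding `OpFoldResult`s, so constant padding now yields a static result type (compare the `tensor<?x?x?x?xf32>` pad in the "before" IR above). A hand-written illustration, not from the test suite:

```
// %t : tensor<4x4xf32> and %cst : f32 assumed defined; with constant
// padding amounts the result type is inferred as tensor<6x6xf32>.
%padded = tensor.pad %t low[1, 1] high[1, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<4x4xf32> to tensor<6x6xf32>
```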
```
@@ -103,25 +97,25 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
   Value c1 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
   Value c2 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(2));
-  Value doublePadding = b.create<arith::MulIOp>(loc, paddingInt, c2);
+  Value doublePadding = b.createOrFold<arith::MulIOp>(loc, paddingInt, c2);
   // in + 2 * padding
-  Value inAddDoublePadding =
-      b.create<arith::AddIOp>(loc, castIndexToInt64(b, loc, in), doublePadding);
+  Value inAddDoublePadding = b.createOrFold<arith::AddIOp>(
+      loc, castIndexToInt64(b, loc, in), doublePadding);
   // dilation * (kernelSize - 1)
-  Value kernelSizeSub1 = b.create<arith::SubIOp>(loc, kernelSizeInt, c1);
+  Value kernelSizeSub1 = b.createOrFold<arith::SubIOp>(loc, kernelSizeInt, c1);
   Value dilationTimesKernelSize =
-      b.create<arith::MulIOp>(loc, dilationInt, kernelSizeSub1);
+      b.createOrFold<arith::MulIOp>(loc, dilationInt, kernelSizeSub1);
-  Value temp =
-      b.create<arith::SubIOp>(loc, inAddDoublePadding, dilationTimesKernelSize);
-  Value dividend = b.create<arith::SubIOp>(loc, temp, c1);
+  Value temp = b.createOrFold<arith::SubIOp>(loc, inAddDoublePadding,
+                                             dilationTimesKernelSize);
+  Value dividend = b.createOrFold<arith::SubIOp>(loc, temp, c1);
   Value division;
   if (ceilMode)
-    division = b.create<arith::CeilDivSIOp>(loc, dividend, strideInt);
+    division = b.createOrFold<arith::CeilDivSIOp>(loc, dividend, strideInt);
   else
-    division = b.create<arith::FloorDivSIOp>(loc, dividend, strideInt);
-  Value out = b.create<arith::AddIOp>(loc, division, c1);
+    division = b.createOrFold<arith::FloorDivSIOp>(loc, dividend, strideInt);
+  Value out = b.createOrFold<arith::AddIOp>(loc, division, c1);
   return castIntToIndex(b, loc, out);
 }
```
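The formula built here is `out = (in + 2*padding - dilation*(kernelSize - 1) - 1) / stride + 1`, with floor or ceil division depending on `ceilMode`. When every operand is constant, the whole chain now folds; for the height/width dimensions of the convolution example above (input extent 112, padding 1, dilation 1, kernel 3, stride 2, floor division):

```
out = floordiv(in + 2*padding - dilation*(kernelSize - 1) - 1, stride) + 1
    = floordiv(112 + 2*1 - 1*(3 - 1) - 1, 2) + 1
    = floordiv(111, 2) + 1
    = 56
```

which is exactly the static extent in the `!torch.vtensor<[32,336,56,56],f32>` result type.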


```
@@ -140,12 +140,12 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
 Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
   assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
-  return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
+  return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);
 }
 
 Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
   assert(isa<IndexType>(idx.getType()) && "must be called with integer type");
-  return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
+  return b.createOrFold<arith::IndexCastOp>(loc, b.getI64Type(), idx);
 }
 
 SmallVector<Value>
```
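These two helpers are used throughout the lowering, so folding them has a broad effect: a cast of a constant now materializes directly as a constant of the target type instead of surviving as an `arith.index_cast`. Illustratively (hand-written IR):

```
// without folding:
%c2_i64 = arith.constant 2 : i64
%0 = arith.index_cast %c2_i64 : i64 to index
// with createOrFold, the cast folds away on construction:
%c2 = arith.constant 2 : index
```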


```
@@ -94,7 +94,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
         if (!isa<Torch::IntType>(inputs[0].getType()))
           return std::nullopt;
         assert(inputs.size() == 1);
-        return builder.create<ToI64Op>(loc, inputs[0]).getResult();
+        return builder.createOrFold<ToI64Op>(loc, inputs[0]);
       });
   auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type,
                                   ValueRange inputs, Location loc) -> Value {
```


```
@@ -67,7 +67,7 @@ func.func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>,
 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> {
 // CHECK: %[[C1:.*]] = torch.constant.int 1
-// CHECK: %[[BUILTIN_C1:.*]] = torch_c.to_i64 %[[C1]]
+// CHECK: %[[BUILTIN_C1:.*]] = arith.constant 1 : i64
 // CHECK: linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>]
 // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32):
 // CHECK: %[[ALPHA:.*]] = arith.sitofp %[[BUILTIN_C1]] : i64 to f32
```


```
@@ -7,13 +7,13 @@ func.func @forward_max_pool1d(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vten
   %int3 = torch.constant.int 3
   %int4 = torch.constant.int 4
   %false = torch.constant.bool false
-// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
+// CHECK: %[[C1:.*]] = arith.constant 1 : i64
 // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
 // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 3] high[0, 0, 3]
 // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
-// CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]]) : tensor<?xf32>
-// CHECK: linalg.pooling_ncw_max {dilations = dense<4> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?xf32>, tensor<?xf32>) outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[T1:.*]] = arith.constant 1 : index
+// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1xf32>
+// CHECK: linalg.pooling_ncw_max {dilations = dense<4> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?xf32>, tensor<1xf32>) outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
   %kernel_size = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
   %stride = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
   %padding = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
@@ -33,15 +33,15 @@ func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vt
   %int7 = torch.constant.int 7
   %int8 = torch.constant.int 8
   %false = torch.constant.bool false
-// CHECK: %[[C1:.*]] = torch_c.to_i64 %int1
-// CHECK: %[[C2:.*]] = torch_c.to_i64 %int2
+// CHECK: %[[C1:.*]] = arith.constant 1 : i64
+// CHECK: %[[C2:.*]] = arith.constant 2 : i64
 // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32
 // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
 // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
-// CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index
-// CHECK: %[[T2:.*]] = arith.index_cast %[[C2]] : i64 to index
-// CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]], %[[T2]]) : tensor<?x?xf32>
-// CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<?x?xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[T1:.*]] = arith.constant 1 : index
+// CHECK: %[[T2:.*]] = arith.constant 2 : index
+// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x2xf32>
+// CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor<?x?x?x?xf32>, tensor<1x2xf32>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   %kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
   %stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
   %padding = torch.prim.ListConstruct %int5, %int6 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -88,7 +88,7 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
 // CHECK: } : tensor<?x?x?x?x?xf32> to tensor<?x?x?x?x?xf32>
 // CHECK: %[[OUTPUT_TENSOR:.*]] = linalg.fill ins(%[[MIN_VALUE:.*]] : f32) outs(%{{.*}} : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-// CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[OUTPUT_TENSOR:.*]] : tensor<?x?x?x?x?xf32>) {
+// CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor<?x?x?x?x?xf32>, tensor<8x8x8xf32>) outs(%[[OUTPUT_TENSOR:.*]] : tensor<?x?x?x?x?xf32>) {
 // CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[KERNEL:.*]]: f32, %[[ACC_OUT:.*]]: f32):
 // CHECK-NEXT: %[[MAXF:.*]] = arith.maximumf %[[CURRENT_VALUE:.*]], %[[ACC_OUT:.*]] : f32
 // CHECK-NEXT: linalg.yield %[[MAXF:.*]] : f32
```


```
@@ -7,10 +7,8 @@
 // CHECK-LABEL: func.func @torch.aten.view$twotothree
 // CHECK: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32>
-// CHECK: %[[T3:.*]] = torch.constant.int 3
-// CHECK: %[[T2:.*]] = torch.constant.int 2
-// CHECK: %[[N2:.*]] = torch_c.to_i64 %[[T2]]
-// CHECK: %[[N3:.*]] = torch_c.to_i64 %[[T3]]
+// CHECK: %[[N2:.*]] = arith.constant 2 : i64
+// CHECK: %[[N3:.*]] = arith.constant 3 : i64
 // CHECK: %[[ELEMENTS:.*]] = tensor.from_elements %[[N2]], %[[N3]] : tensor<2xi64>
 // CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[ARG0]](%[[ELEMENTS]]) : (tensor<3x2xf32>, tensor<2xi64>) -> tensor<2x3xf32>
 func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32>
@@ -112,13 +110,12 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -
 // reshape. Someday, this should generate flatten/unflatten.
 // CHECK-LABEL: func.func @torch.aten$dynamicValOutput
 // CHECK: %[[SELF:.*]] = torch_c.to_builtin_tensor %arg0
-// CHECK: %[[CONSTANT1:.*]] = torch.constant.int 1
 // CHECK-DAG: %[[PROD1:.*]] = arith.constant 1
 // CHECK-DAG: %[[ARG1_CVT:.*]] = torch_c.to_i64 %arg1
 // CHECK-DAG: %[[PROD2:.*]] = arith.muli %[[PROD1]], %[[ARG1_CVT]]
-// CHECK-DAG: %[[ONEI64:.*]] = torch_c.to_i64 %[[CONSTANT1]]
+// CHECK-DAG: %[[ONEI64:.*]] = arith.constant 1 : i64
 // CHECK-DAG: %[[PROD3:.*]] = arith.muli %[[PROD2]], %[[ONEI64]]
-// CHECK-DAG: %[[ONEI64_0:.*]] = torch_c.to_i64 %[[CONSTANT1]]
+// CHECK-DAG: %[[ONEI64_0:.*]] = arith.constant 1 : i64
 // CHECK-DAG: %[[PROD4:.*]] = arith.muli %[[PROD3]], %[[ONEI64_0]]
 // CHECK-DAG: %[[INDEX0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[DIM0_INDEX:.*]] = tensor.dim %[[SELF]], %[[INDEX0]] : tensor<?x?x?xf32>
@@ -134,8 +131,8 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -
 // CHECK-DAG: %[[KNOWN2:.*]] = arith.muli %[[KNOWN1]], %[[DIM2]] : i64
 // CHECK-DAG: %[[DIMINFER:.*]] = arith.divui %[[KNOWN2]], %[[PROD4]] : i64
 // CHECK: %[[DIM0:.*]] = torch_c.to_i64 %arg1
-// CHECK: %[[DIM1:.*]] = torch_c.to_i64 %[[CONSTANT1]]
-// CHECK: %[[DIM3:.*]] = torch_c.to_i64 %[[CONSTANT1]]
+// CHECK: %[[DIM1:.*]] = arith.constant 1 : i64
+// CHECK: %[[DIM3:.*]] = arith.constant 1 : i64
 // CHECK: %[[OUTPUT_DIMS:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]], %[[DIMINFER]], %[[DIM3]] : tensor<4xi64>
 // CHECK: tensor.reshape %[[SELF]](%[[OUTPUT_DIMS]]) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x1x?x1xf32>
 //
```


```
@@ -4,9 +4,9 @@
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_i1 %[[VAL_0]]
 // CHECK: %[[VAL_2:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_3:.*]] = torch_c.to_i64 %[[VAL_2]]
+// CHECK: %[[VAL_3:.*]] = arith.constant 2 : i64
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 1
-// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]]
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : i64
 // CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_1]] -> (i64) {
 // CHECK: scf.yield %[[VAL_3]] : i64
 // CHECK: } else {
@@ -31,11 +31,11 @@ func.func @torch.prim.if(%arg0: !torch.bool) -> !torch.int {
 // CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_i1 %[[VAL_0]]
 // CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_i1 %[[VAL_1]]
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]]
+// CHECK: %[[VAL_5:.*]] = arith.constant 2 : i64
 // CHECK: %[[VAL_6:.*]] = torch.constant.int 3
-// CHECK: %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]]
+// CHECK: %[[VAL_7:.*]] = arith.constant 3 : i64
 // CHECK: %[[VAL_8:.*]] = torch.constant.int 4
-// CHECK: %[[VAL_9:.*]] = torch_c.to_i64 %[[VAL_8]]
+// CHECK: %[[VAL_9:.*]] = arith.constant 4 : i64
 // CHECK: %[[VAL_10:.*]] = scf.if %[[VAL_2]] -> (i64) {
 // CHECK: %[[VAL_11:.*]] = scf.if %[[VAL_3]] -> (i64) {
 // CHECK: scf.yield %[[VAL_5]] : i64
```


```
@@ -103,8 +103,7 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
 // CHECK-LABEL: func.func @torch.aten.addscalar$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
@@ -124,10 +123,8 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 // CHECK-LABEL: func.func @torch.aten.addscalar$alpha(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
@@ -167,8 +164,7 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
 // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
@@ -204,8 +200,7 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
 // CHECK-LABEL: func.func @torch.aten.subscalar$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
@@ -225,8 +220,7 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 // CHECK-LABEL: func.func @torch.aten.rsubscalar$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
@@ -246,10 +240,8 @@ func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
 // CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
@@ -289,8 +281,7 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
 // CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64>
 // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor<f32>
@@ -326,8 +317,7 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1
 // CHECK-LABEL: func.func @torch.aten.mulscalar$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
@@ -359,8 +349,7 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK-LABEL: func.func @torch.aten.divscalar$basic(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT9:.*]] = torch.constant.int 9
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[T1:.*]] = arith.constant 9 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32>
 // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor<f32>
@@ -392,8 +381,7 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
 // CHECK-LABEL: func.func @torch.aten.gt.scalar(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]]
+// CHECK: %[[T1:.*]] = arith.constant 3 : i64
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
 // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : tensor<1xi64>
 // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor<i64>
```


```
@@ -278,7 +278,7 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
 // CHECK: %[[T_5:.*]] = torch.constant.int 1
 // CHECK: %[[T_6:.*]] = torch.constant.int 4
 // CHECK: %[[T_7:.*]] = torch.constant.int 3
-// CHECK: %[[T_8:.*]] = torch_c.to_i64 %[[T_7]]
+// CHECK: %[[T_8:.*]] = arith.constant 3 : i64
 // CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
@@ -314,8 +314,7 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
 // CHECK: %int2 = torch.constant.int 2
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %int4 = torch.constant.int 4
-// CHECK: %int3 = torch.constant.int 3
-// CHECK: %[[T_3:.*]] = torch_c.to_i64 %int3
+// CHECK: %[[T_3:.*]] = arith.constant 3 : i64
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -357,7 +356,7 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
 // CHECK: %none = torch.constant.none
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
+// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32>
@@ -388,7 +387,7 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,
 // CHECK: %none = torch.constant.none
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
+// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
 // CHECK: %int2 = torch.constant.int 2
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -423,7 +422,7 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7
 // CHECK: %none = torch.constant.none
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
+// CHECK: %[[T_2:.*]] = arith.constant 1 : i64
 // CHECK: %int2 = torch.constant.int 2
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
@@ -459,12 +458,12 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
 // CHECK: %int0 = torch.constant.int 0
 // CHECK: %int1 = torch.constant.int 1
 // CHECK: %int2 = torch.constant.int 2
-// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int2
+// CHECK: %[[T_2:.*]] = arith.constant 2 : i64
 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32>
-// CHECK: %[[T_7:.*]] = stablehlo.reverse %6, dims = [0, 1] : tensor<3x3x2x2xf32>
+// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32>
 // CHECK: %c0 = arith.constant 0 : index
 // CHECK: %dim = tensor.dim %[[T_7]], %c0 : tensor<3x3x2x2xf32>
 // CHECK: %[[T_8:.*]] = arith.index_cast %dim : index to i64
@@ -477,14 +476,14 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor
 // CHECK: %c3 = arith.constant 3 : index
 // CHECK: %dim_2 = tensor.dim %[[T_7]], %c3 : tensor<3x3x2x2xf32>
 // CHECK: %[[T_11:.*]] = arith.index_cast %dim_2 : index to i64
-// CHECK: %c2_i64 = arith.constant 2 : i64
-// CHECK: %[[T_12:.*]] = arith.divsi %[[T_11]], %c2_i64 : i64
-// CHECK: %[[T_13:.*]] = arith.muli %[[T_10]], %c2_i64 : i64
-// CHECK: %from_elements = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_10]], %c2_i64, %[[T_12]] : tensor<5xi64>
+// CHECK: %[[C2:.*]] = arith.constant 2 : i64
+// CHECK: %[[T_12:.*]] = arith.divsi %[[T_11]], %[[C2]] : i64
+// CHECK: %[[T_13:.*]] = arith.muli %[[T_10]], %[[C2]] : i64
+// CHECK: %from_elements = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_10]], %[[C2]], %[[T_12]] : tensor<5xi64>
 // CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xi64>) -> tensor<3x3x2x2x1xf32>
 // CHECK: %[[T_15:.*]] = stablehlo.transpose %[[T_14]], dims = [0, 1, 3, 2, 4] : (tensor<3x3x2x2x1xf32>) -> tensor<3x3x2x2x1xf32>
-// CHECK: %from_elements_3 = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64>
-// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %from_elements_3 : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32>
+// CHECK: %[[from_elements_3:.*]] = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64>
+// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %[[from_elements_3]] : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32>
 // CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]])
 // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x1xf32>) -> tensor<1x4x15x15xf32>
 // CHECK: %[[T_18:.*]] = torch_c.from_builtin_tensor %[[T_17]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
```


```
@@ -3,12 +3,9 @@
 // CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]]
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
-// CHECK: %[[INT10:.*]] = torch.constant.int 10
-// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT10]]
+// CHECK: %[[T1:.*]] = arith.constant 0 : i64
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
+// CHECK: %[[T3:.*]] = arith.constant 10 : i64
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?xf32>
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64
@@ -42,7 +39,7 @@
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[?,?,?],f32>
 func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -58,12 +55,9 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32
 // CHECK-LABEL: func.func @torch.aten.slice.strided.static$slice_like(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32>
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]]
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]]
-// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807
-// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]]
+// CHECK: %[[T1:.*]] = arith.constant 0 : i64
+// CHECK: %[[T2:.*]] = arith.constant 2 : i64
+// CHECK: %[[T3:.*]] = arith.constant 9223372036854775807 : i64
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32>
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64
@@ -97,7 +91,7 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[2,65,256],f32>
 func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> {
@@ -113,12 +107,9 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6
 // CHECK-LABEL: func.func @torch.aten.slice.last$slice_like(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]]
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]]
-// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
-// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1
+// CHECK: %[[T1:.*]] = arith.constant 0 : i64
+// CHECK: %[[T2:.*]] = arith.constant 1 : i64
+// CHECK: %[[T3:.*]] = arith.constant -1 : i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?xf32>
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64
@@ -152,7 +143,7 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x1x?xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<?x1x?xf32> -> !torch.vtensor<[?,1,?],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[?,1,?],f32>
 func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> {
@@ -168,12 +159,9 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
 // CHECK-LABEL: func.func @torch.aten.slice.last.static$slice_like(
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32>
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]]
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]]
-// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
-// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1
+// CHECK: %[[T1:.*]] = arith.constant 0 : i64
+// CHECK: %[[T2:.*]] = arith.constant 1 : i64
+// CHECK: %[[T3:.*]] = arith.constant -1 : i64
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32>
 // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64
@@ -207,7 +195,7 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
+// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32>
 // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32>
 // CHECK: return %[[T23]] : !torch.vtensor<[4,1,256],f32>
 func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> {
@@ -224,8 +212,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T1:.*]] = arith.constant 2 : i64
 // CHECK: %[[NONE:.*]] = torch.constant.none
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?xf32>
@@ -247,7 +234,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
+// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS_6]] : (tensor<?x?x?xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<?x?x?xf32>
 // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[T9]] : !torch.vtensor<[?,?,?],f32>
 func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -264,8 +251,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32>
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]]
+// CHECK: %[[T1:.*]] = arith.constant 2 : i64
 // CHECK: %[[NONE:.*]] = torch.constant.none
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32>
@@ -287,7 +273,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>)
 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64>
 // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64>
-// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
+// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS_6]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32>
 // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32>
 // CHECK: return %[[T9]] : !torch.vtensor<[4,33,256],f32>
 func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> {
```