From 738d45d3bbabbd0c1026cf923d7ebeec19eb244f Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Fri, 1 Nov 2024 14:56:48 -0500 Subject: [PATCH] add scalarization patterns to support dynamic pytorch pad exports (#3838) 1. Adds case handling for `aten.slice.tensor` shape inference with negative strides. This is not technically allowed by native pytorch, but it is useful for ONNX ingest. We were getting some incorrect shapes for these negative strided slice ops. 2. Adds scalarization support for ops seen in pytorch pad exports to ONNX. These are typically `aten.view` `aten.transpose.int` and `aten.slice.Tensor` with negative strides (and rank 2). 3. Allows view op `self` to be added to the worklist conditionally, based on whether the view op actually occurs as a middle point in a shape computation. --- .../Transforms/AbstractInterpLibrary.cpp | 62 ++++- .../Torch/Transforms/ScalarizeShapes.cpp | 246 +++++++++++++++-- .../build_tools/abstract_interp_lib_gen.py | 28 +- test/Dialect/Torch/scalarize-shapes.mlir | 259 ++++++++++++++++++ 4 files changed, 568 insertions(+), 27 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 233c6be7e..ead29d59a 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10131,8 +10131,66 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %2 : !torch.tuple, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" -" return %0 : !torch.list\n" +" %int-1 = torch.constant.int -1\n" +" %none = 
torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" }\n" +" %2 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" %9 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" %9 = func.call @__torch__.torch.jit._shape_functions.max_int() : () -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" %4 = torch.aten.lt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %5:3 = torch.prim.If %4 -> (!torch.int, !torch.int, !torch.int) {\n" +" %9 = torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.add.int %1, %20 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %21 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1 : !torch.int\n" +" }\n" +" %11 = torch.aten.lt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %10 : !torch.int\n" +" }\n" +" %13 = torch.aten.lt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.int) {\n" +" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.add.int %3, %20 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %21 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %3 : 
!torch.int\n" +" }\n" +" %15 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.int) {\n" +" torch.prim.If.yield %int-1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %14 : !torch.int\n" +" }\n" +" %17 = torch.aten.add.int %16, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten.add.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %19 = torch.aten.neg.int %arg4 : !torch.int -> !torch.int\n" +" torch.prim.If.yield %17, %18, %19 : !torch.int, !torch.int, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1, %3, %arg4 : !torch.int, !torch.int, !torch.int\n" +" }\n" +" %6 = torch.derefine %5#0 : !torch.int to !torch.optional\n" +" %7 = torch.derefine %5#1 : !torch.int to !torch.optional\n" +" %8 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %6, %7, %5#2) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" +" return %8 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional) -> !torch.list {\n" " return %arg1 : !torch.list\n" diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 345b5e156..9a85fbaa8 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -17,6 +17,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; @@ -310,7 +311,9 @@ public: auto selfShape = selfTy.getSizes(); int64_t selfRank = selfShape.size(); - dim = dim < 0 ? 
dim + selfRank : dim; + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return failure(); int64_t dimLength = elements.size(); if (selfShape[dim] != dimLength) return rewriter.notifyMatchFailure( @@ -362,6 +365,11 @@ public: auto loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); + auto selfTy = cast(op.getSelf().getType()); + auto resultTy = cast(op.getType()); + if (!selfTy.areAllSizesKnown() || !resultTy.areAllSizesKnown()) + return rewriter.notifyMatchFailure(op, "requires static sizes"); + SmallVector elements; if (failed(getListFromTensor(op.getSelf(), elements))) return failure(); @@ -379,39 +387,69 @@ public: if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) return rewriter.notifyMatchFailure(op, "requires a constant step"); - if (step < 0) - return rewriter.notifyMatchFailure(op, "requires a positive step value"); - - auto selfTy = cast(op.getSelf().getType()); auto selfShape = selfTy.getSizes(); + auto resultShape = resultTy.getSizes(); int64_t selfRank = selfShape.size(); // Correct for negative indexing: - dim = dim < 0 ? dim + selfRank : dim; + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return failure(); - int64_t dimLength = elements.size(); + int64_t dimLength = selfShape[dim]; start = start < 0 ? start + dimLength : start; end = end < 0 ? end + dimLength : end; + end = (end < 0) ? -1 : end; + end = (end < 0 && step > 0) ? 0 : end; start = start < 0 ? 0 : start; - end = end < 0 ? 0 : end; end = end > dimLength ? 
dimLength : end; - if (selfShape[dim] != dimLength) - return rewriter.notifyMatchFailure( - op, "dim length does not match number of elements"); + int64_t frontDimProd = 1, backDimProd = 1; + for (int64_t i = 0; i < selfRank; i++) { + if (i < dim) + frontDimProd *= selfShape[i]; + if (i > dim) + backDimProd *= selfShape[i]; + } + int64_t fullDimProd = frontDimProd * dimLength * backDimProd; + if (fullDimProd != (int64_t)elements.size()) + return rewriter.notifyMatchFailure(op, "unexpected number of elements."); - for (int64_t i = 0; i < selfRank; ++i) { - if (i == dim) + // [d0,d1] i -> (i//d1, i % d1) -> (i//d1) * d1 + (i % d1) + // [d0,d1,d2] i -> (i//d2, i%d2) -> ((i//(d1*d2), (i//d2) % d1, i % d2) + + auto isSliceIdx = [&](int64_t i) { + int64_t dimidx = (i / backDimProd) % dimLength; + bool onStep = ((dimidx - start) % step == 0); + bool beforeEnd = (step < 0 && dimidx > end); + beforeEnd = beforeEnd || (step > 0 && dimidx < end); + bool afterBegin = (step < 0 && dimidx <= start); + afterBegin = afterBegin || (step > 0 && dimidx >= start); + return onStep && beforeEnd && afterBegin; + }; + + auto flipIdx = [&](int64_t i) { + int64_t frontIdx = (i / (backDimProd * dimLength)); + int64_t dimIdx = (i / (backDimProd)) % dimLength; + int64_t flipDimIdx = dimLength - 1 - dimIdx; + int64_t backIdx = i % (backDimProd); + return frontIdx * (dimLength * backDimProd) + flipDimIdx * (backDimProd) + + backIdx; + }; + SmallVector selected; + for (int64_t i = 0; i < (int64_t)elements.size(); i++) { + if (!isSliceIdx(i)) continue; - if (selfShape[i] != 1) - return rewriter.notifyMatchFailure(op, - "expects unary non-dim dimension"); + int64_t index = (step > 0) ? 
i : flipIdx(i); + selected.push_back(elements[index]); } - SmallVector selected; - for (int i = start; i < end; i += step) - selected.push_back(elements[i]); + fullDimProd = (fullDimProd * resultShape[dim]) / selfShape[dim]; + if ((int64_t)selected.size() != fullDimProd) + return rewriter.notifyMatchFailure( + op, "Constructed slice values have an incompatable number of " + "elements to match the provided return type."); SmallVector values; if (failed(materializeFolds(b, selected, values))) @@ -424,6 +462,114 @@ public: }; } // namespace +namespace { +class PropagateAtenTransposeIntPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTransposeIntOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + + auto selfTy = cast(op.getSelf().getType()); + auto resultTy = cast(op.getType()); + if (!selfTy.areAllSizesKnown() || !resultTy.areAllSizesKnown()) + return rewriter.notifyMatchFailure(op, "requires static sizes"); + + SmallVector elements; + if (failed(getListFromTensor(op.getSelf(), elements))) + return failure(); + + int64_t dim0, dim1; + if (!matchPattern(op.getDim0(), m_TorchConstantInt(&dim0))) + return failure(); + if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) + return failure(); + + ArrayRef selfSizes = selfTy.getSizes(); + int64_t rank = selfSizes.size(); + + dim0 = toPositiveDim(dim0, rank); + dim1 = toPositiveDim(dim1, rank); + if (!isValidDim(dim0, rank) || !isValidDim(dim0, rank)) + return failure(); + + if (dim0 == dim1) { + rewriter.replaceOp(op, op.getSelf()); + return success(); + } + + if (dim0 > dim1) { + // swap dim0 and dim1 + dim0 = dim0 + dim1; + dim1 = dim0 - dim1; + dim0 -= dim1; + } + + // A generic transpose will look like... + // [frontDimsFlat, dim0, midDimsFlat, dim1, backDimsFlat] -> . + // [frontDimsFlat, dim1, midDimsFlat, dim0, backDimsFlat] . 
+ // If any of front, mid, or back don't actually exist (e.g. dim0 = 0, or + // dim1 = dim0 + 1), the reassociation of completely flattened indices will + // remain unaffected by the artificially unsqueezed dims. + // -------- + // Setting some notation, let D0,D1,D2,D3,D4 be the respective dim sizes of + // "self". Let D'j be the transpose dim sizes, and Djk = Dj*Dk. Let fl_trans + // and fl_self be 1-D flattened tensors. Then: + // -------- + // fl_trans[i] = + // = trans[i/D'1234, i/(D'234) % D'1, i/(D'34) % D'2, i/D'4 % D'3, i % D'4] + // = trans[i/D1234, i/D214 % D3, i/D14 % D2, i/D4 % D1, i % D4] + // = self[i/D1234, i/D4 % D1, i/D14 % D2, i/D214 % D3, i % D4] + // = fl_self[dot.prod(indices, (D1234,D234,D34,D4,1))] . + // -------- + // reassoc(i) = (i/(D1234)) * D1234 + + // (i/D4 % D1) * D234 + + // (i/(D14) % D2) * D34 + + // (i/(D214) % D3) * D4 + + // (i % D4) . + + SmallVector D(5, 1); + int64_t i = -1; + // D[0] corresponds to flattened front dims + while (++i < dim0) + D[0] *= selfSizes[i]; + // D[1] is the earliest transpose dim + D[1] = selfSizes[i]; + // D[2] corresponds to flattened middle dims + while (++i < dim1) + D[2] *= selfSizes[i]; + // D[3] is the later transpose dim + D[3] = selfSizes[i]; + // D[4] corresponds to flattened back dims + while (++i < rank) + D[4] *= selfSizes[i]; + + int64_t D1234 = D[1] * D[2] * D[3] * D[4]; + int64_t fullDP = D[0] * D1234; + if (fullDP != (int64_t)elements.size()) + return failure(); + auto reassoc = [&](int64_t i) { + return (i / D1234) * D1234 + ((i / D[4]) % D[1]) * D[2] * D[3] * D[4] + + ((i / (D[1] * D[4])) % D[2]) * D[3] * D[4] + + ((i / (D[2] * D[1] * D[4])) % D[3]) * D[4] + (i % D[4]); + }; + SmallVector transposedFolds; + transposedFolds.reserve(fullDP); + for (int64_t i = 0; i < fullDP; i++) + transposedFolds.push_back(elements[reassoc(i)]); + + SmallVector transposedVals; + if (failed(materializeFolds(b, transposedFolds, transposedVals))) + return failure(); + + Value result = 
constructAtenTensorOpFromList(b, resultTy, transposedVals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace namespace { class PropagateAtenWhereSelfPattern : public OpRewritePattern { public: @@ -600,6 +746,27 @@ public: }; } // namespace +namespace { +template +class PropagateAtenViewLikePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenViewLikeOp op, + PatternRewriter &rewriter) const override { + SmallVector selfFolds; + if (failed(getListFromTensor(op.getSelf(), selfFolds))) + return failure(); + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + SmallVector selfVals; + if (failed(materializeFolds(b, selfFolds, selfVals))) + return failure(); + Value result = constructAtenTensorOpFromList(b, op.getType(), selfVals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + namespace { template struct ArithmeticHelper { @@ -1065,6 +1232,34 @@ bool isAnchorOp(Operation *op) { isPrimListOfInts(op); } +// The argument to this function, op, is the use of some source op, srcOp. If +// this function returns true, we want to invalidate srcOp as a target for shape +// scalarization. +bool isInvalidValidViewConsumer(Operation *op, + SetVector &workList) { + // if the consumer isn't a view op, don't invalidate it + auto view = dyn_cast_or_null(op); + if (!view) + return false; + auto resultTy = dyn_cast(view.getType()); + if (!resultTy || !resultTy.hasDtype()) + return true; + // if the view op doesn't return integer types, then srcOp is not a shape + // tensor. note: prim lists will always get added before reaching this + // function call. + if (!isa(resultTy.getDtype())) + return true; + // check uses of the view op. + // If the view op has a use in our worklist, then it needs to be scalarized. 
+ for (OpOperand &use : op->getUses()) { + Operation *userOp = use.getOwner(); + if (workList.contains(userOp)) + return false; + } + // invalidate, since the view op was added as a one-off for canonicalization. + return true; +} + void populateScalarizationFoldPatterns(RewritePatternSet &patterns) { patterns.insert, FoldAtenSqueezePattern, @@ -1078,6 +1273,11 @@ void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) { } void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) { + patterns.add>(patterns.getContext(), + /*benefit=*/10); + patterns.insert, + PropagateAtenViewLikePattern>( + patterns.getContext()); // A note on division: onnx.Div from int, int -> int types rounds towards // zero. The torch DivTensorOp actually doesn't allow returning an int dtype, // but this was artificially plummbed through. Unfortunately, there is no @@ -1088,6 +1288,7 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) { PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern, PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern, + PropagateAtenTransposeIntPattern, PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern, @@ -1105,9 +1306,6 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) { RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, - RemoveUnusedPattern, - RemoveUnusedPattern, - RemoveUnusedPattern, RemoveUnusedPattern>( patterns.getContext()); } @@ -1168,12 +1366,12 @@ public: // shapeCalculationOps. It's consumer (%1) is indeed a shape // calculation op, but the size.int op is an elementary unit of shape // computation. No futher gathering of producers is necessary to - // reduce this. Similarly, don't add the `self` of a view op. + // reduce this. Similarly, don't always add the `self` of a view op. 
for (OpOperand &use : op->getUses()) { Operation *userOp = use.getOwner(); if (shapeCalculationOps.contains(userOp) && !isSourceOpForShapeScalarization(userOp) && - !isa(userOp)) { + !isInvalidValidViewConsumer(userOp, shapeCalculationOps)) { shapeCalculationOps.insert(op); return; } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 65cc18837..06437574d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1903,7 +1903,33 @@ def aten〇_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int = return upstream_shape_functions.unary(v), upstream_shape_functions.unary(g) def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: - return upstream_shape_functions.slice(self, dim, start, end, step) + start_val = start if start is not None else 0 + end_val = end if end is not None else upstream_shape_functions.max_int() + if (step < 0): + # Convert to equivalent positive-step parameters, which will require swapping start and end. + # If the parameters are in the normal range (0 <= start < d and -1 <= end <= start), then + # swapped_end = start + 1 and swapped_begin = end + 1. + # The shift of inclusion can cause issues if these parameters are not already resolved on the left. + # e.g. start = -1, end = -3. So valid start is actually d-1, and valid end is d-3. Therefore, we + # should have swapped_end = d, but adding 1 to start before making it valid would result in an + # incorrect, but "valid", swapped_end = 0 for forward slicing.
+ # Additionally, if adding d doesn't make these values positive, but adding twice would, we need + # to clamp after resolving, otherwise the upstream function will try to resolve a second time. + if start_val < 0: + start_val += self[dim] + if start_val < 0: + start_val = 0 + if end_val < 0: + end_val += self[dim] + if end_val < 0: + end_val = -1 + + tmp = end_val + 1 + end_val = start_val + 1 + start_val = tmp + step = -step + return upstream_shape_functions.slice(self,dim,start_val,end_val,step) + def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]: return size diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir index 166e2fda5..f5193b701 100644 --- a/test/Dialect/Torch/scalarize-shapes.mlir +++ b/test/Dialect/Torch/scalarize-shapes.mlir @@ -374,3 +374,262 @@ func.func @squeeze_dim_full_fold(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.li %59 = torch.prim.ListConstruct %58 : (!torch.int) -> !torch.list return %59 : !torch.list } + +// ----- + +// CHECK-LABEL: @pytorch_dynamic_pad_export_view$prop( +func.func @pytorch_dynamic_pad_export_view$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[4,2],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[x2:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I144:.*]] = torch.constant.int 144 + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[I0_3:.*]] = torch.constant.int 0 
+ // CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[x0]], %[[I144]], %[[x1]], %[[x2]], %[[I0_0]], %[[I0_1]], %[[I0_2]], %[[I0_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[none:.*]] = torch.constant.none + // CHECK: %[[false:.*]] = torch.constant.bool false + // CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4,2],si64> + // CHECK: return %[[x4]] : !torch.vtensor<[4,2],si64> + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64> + %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %5 = torch.aten.cat %4, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list + %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + %9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + %10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list -> !torch.vtensor<[8],si64> + %12 
= torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int + %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list + return %7 : !torch.vtensor<[4,2],si64> +} + +// ----- + +// CHECK-LABEL: @pytorch_dynamic_pad_export_slice$prop( +func.func @pytorch_dynamic_pad_export_slice$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[4,2],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[x2:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[I0_3:.*]] = torch.constant.int 0 + // CHECK: %[[I144:.*]] = torch.constant.int 144 + // CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[I0_2]], %[[I0_3]], %[[x1]], %[[x2]], %[[x0]], %[[I144]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[none:.*]] = torch.constant.none + // CHECK: %[[false:.*]] = torch.constant.bool false + // CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4,2],si64> + // CHECK: 
return %[[x4]] : !torch.vtensor<[4,2],si64> + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64> + %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %5 = torch.aten.cat %4, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list + %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + %9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + %10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list -> !torch.vtensor<[8],si64> + %12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int + %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list + return %8 : 
!torch.vtensor<[4,2],si64> +} + +// ----- + +// CHECK-LABEL: @pytorch_dynamic_pad_export_transpose$prop( +func.func @pytorch_dynamic_pad_export_transpose$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[2,4],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[DIM3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[I0_3:.*]] = torch.constant.int 0 + // CHECK: %[[DIM1:.*]] = torch.constant.int 144 + // CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[DIM2]], %[[DIM0]], %[[I0_2]], %[[I0_3]], %[[DIM3]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[none:.*]] = torch.constant.none + // CHECK: %[[false:.*]] = torch.constant.bool false + // CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,4],si64> + // CHECK: return %[[x4]] : !torch.vtensor<[2,4],si64> + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + 
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64>
+  %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
+  %5 = torch.aten.cat %4, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
+  %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
+  %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
+  %9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
+  %10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
+  %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
+  %12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
+  %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int
+  %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list<int>
+  return %9 : !torch.vtensor<[2,4],si64>
+}
+
+// -----
+// Full pad-export chain: shape(%arg0) concatenated with four zeros, viewed as [4,2], sliced along dim 0 with step -1, transposed, flattened, then items 2 and 0 are read; the expected result is the scalar list [size(%arg0, 2), 0].
+// CHECK-LABEL: @pytorch_dynamic_pad_export_full(
+func.func @pytorch_dynamic_pad_export_full(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.list<int> {
+  // CHECK: %[[I2:.*]] = torch.constant.int 2
+  // CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[x1:.*]] = torch.prim.ListConstruct %[[DIM2]], %[[I0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: return %[[x1]] : !torch.list<int>
+  %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %1 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %int-9223372036854775807 = torch.constant.int -9223372036854775807
+  %int-1 = torch.constant.int -1
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64>
+  %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
+  %5 = torch.aten.cat %4, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
+  %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
+  %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> // start -1, step -1: the expected output corresponds to the rows being taken in reverse order
+  %9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
+  %10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
+  %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
+  %12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
+  %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int
+  %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list<int>
+  return %16 : !torch.list<int>
+}
+
+// -----
+// Transposes dims 0 and 1 of a [2,2,2] view of shape(%arg0) ++ shape(%arg1); with a = %arg0 sizes and b = %arg1 sizes, the expected flat element order is a0, a1, b0, b1, a2, a3, b2, b3.
+// CHECK-LABEL: @transpose$prop_3d_0_1
+func.func @transpose$prop_3d_0_1(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[2,2,2],si64> {
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[SIZE0_0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I1:.*]] = torch.constant.int 1
+  // CHECK: %[[SIZE0_1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I2:.*]] = torch.constant.int 2
+  // CHECK: %[[SIZE0_2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I3:.*]] = torch.constant.int 3
+  // CHECK: %[[SIZE0_3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[SIZE1_0:.*]] = torch.aten.size.int %arg1, %[[I0_0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[SIZE1_1:.*]] = torch.aten.size.int %arg1, %[[I1_1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I2_2:.*]] = torch.constant.int 2
+  // CHECK: %[[SIZE1_2:.*]] = torch.aten.size.int %arg1, %[[I2_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I3_3:.*]] = torch.constant.int 3
+  // CHECK: %[[SIZE1_3:.*]] = torch.aten.size.int %arg1, %[[I3_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SIZE0_0]], %[[SIZE0_1]], %[[SIZE1_0]], %[[SIZE1_1]], %[[SIZE0_2]], %[[SIZE0_3]], %[[SIZE1_2]], %[[SIZE1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,2,2],si64>
+  // CHECK: return %[[TENSOR]] : !torch.vtensor<[2,2,2],si64>
+  %0 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %int-1 = torch.constant.int -1
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %1 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
+  %2 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
+  %3 = torch.prim.ListConstruct %1, %2 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
+  %4 = torch.aten.cat %3, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
+  %5 = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.aten.view %4, %5 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[2,2,2],si64>
+  %7 = torch.aten.transpose.int %6, %int0, %int1 : !torch.vtensor<[2,2,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,2,2],si64>
+  %8 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
+  %9 = torch.aten.view %7, %8 : !torch.vtensor<[2,2,2],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
+  %10 = torch.aten.index_select %9, %int0, %0 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  %11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int
+  %12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list<int>
+  return %7 : !torch.vtensor<[2,2,2],si64> // returns the transposed tensor itself; %8-%12 are unused by the return and presumably exist so the transpose feeds a shape computation (confirm against the pass)
+}
+
+// -----
+// Transposes dims -1 and 0 of a [2,2,2] view of shape(%arg0) ++ shape(%arg1); with a = %arg0 sizes and b = %arg1 sizes, the expected flat element order is a0, b0, a2, b2, a1, b1, a3, b3.
+// CHECK-LABEL: @transpose$prop_3d_m1_0
+func.func @transpose$prop_3d_m1_0(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[2,2,2],si64> {
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[SIZE0_0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I1:.*]] = torch.constant.int 1
+  // CHECK: %[[SIZE0_1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I2:.*]] = torch.constant.int 2
+  // CHECK: %[[SIZE0_2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I3:.*]] = torch.constant.int 3
+  // CHECK: %[[SIZE0_3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[SIZE1_0:.*]] = torch.aten.size.int %arg1, %[[I0_0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[SIZE1_1:.*]] = torch.aten.size.int %arg1, %[[I1_1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I2_2:.*]] = torch.constant.int 2
+  // CHECK: %[[SIZE1_2:.*]] = torch.aten.size.int %arg1, %[[I2_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[I3_3:.*]] = torch.constant.int 3
+  // CHECK: %[[SIZE1_3:.*]] = torch.aten.size.int %arg1, %[[I3_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SIZE0_0]], %[[SIZE1_0]], %[[SIZE0_2]], %[[SIZE1_2]], %[[SIZE0_1]], %[[SIZE1_1]], %[[SIZE0_3]], %[[SIZE1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,2,2],si64>
+  // CHECK: return %[[TENSOR]] : !torch.vtensor<[2,2,2],si64>
+  %0 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %int-1 = torch.constant.int -1
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %1 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
+  %2 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
+  %3 = torch.prim.ListConstruct %1, %2 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
+  %4 = torch.aten.cat %3, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
+  %5 = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.aten.view %4, %5 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[2,2,2],si64>
+  %7 = torch.aten.transpose.int %6, %int-1, %int0 : !torch.vtensor<[2,2,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,2,2],si64>
+  %8 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
+  %9 = torch.aten.view %7, %8 : !torch.vtensor<[2,2,2],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
+  %10 = torch.aten.index_select %9, %int0, %0 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  %11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int
+  %12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list<int>
+  return %7 : !torch.vtensor<[2,2,2],si64> // returns the transposed tensor itself; %8-%12 are unused by the return and presumably exist so the transpose feeds a shape computation (confirm against the pass)
+}