add scalarization patterns to support dynamic pytorch pad exports (#3838)

1. Adds case handling for `aten.slice.Tensor` shape inference with
negative strides. This is not technically allowed by native PyTorch, but
it is useful for ONNX ingest. We were getting incorrect shapes for
these negative-strided slice ops.
2. Adds scalarization support for ops seen in PyTorch pad exports to
ONNX. These are typically `aten.view`, `aten.transpose.int`, and
`aten.slice.Tensor` with negative strides (and rank 2).
3. Allows a view op's `self` to be added to the worklist conditionally,
based on whether the view op actually occurs as a midpoint in a
shape computation.
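For reference (not part of the diff), a minimal sketch of the shape rule item 1 targets, using Python list slicing as an oracle for ONNX-style negative-step slices; the constants mirror the pad-export pattern exercised in the new tests:

```python
# Reference semantics only: Python list slicing as an oracle for the
# result length of a negative-step slice along a single dimension.
def expected_slice_len(dim_size, start, end, step):
    return len(list(range(dim_size))[start:end:step])

# The ONNX pad export slices dim 0 of a [4,2] shape tensor with
# start=-1, end=-9223372036854775807, step=-1 and should keep all 4 rows.
assert expected_slice_len(4, -1, -9223372036854775807, -1) == 4
```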
zjgarvey · 2024-11-01 14:56:48 -05:00 · commit 738d45d3bb · parent 39d69db5ca · pull/3845/head
4 changed files with 568 additions and 27 deletions


@ -10131,8 +10131,66 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %2 : !torch.tuple<list<int>, list<int>>\n" " return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n" " %int-1 = torch.constant.int -1\n"
" return %0 : !torch.list<int>\n" " %none = torch.constant.none\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %9 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int0 : !torch.int\n"
" }\n"
" %2 = torch.aten.__isnot__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
" %9 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %9 : !torch.int\n"
" } else {\n"
" %9 = func.call @__torch__.torch.jit._shape_functions.max_int() : () -> !torch.int\n"
" torch.prim.If.yield %9 : !torch.int\n"
" }\n"
" %4 = torch.aten.lt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %5:3 = torch.prim.If %4 -> (!torch.int, !torch.int, !torch.int) {\n"
" %9 = torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %10 = torch.prim.If %9 -> (!torch.int) {\n"
" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %21 = torch.aten.add.int %1, %20 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %21 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %1 : !torch.int\n"
" }\n"
" %11 = torch.aten.lt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %12 = torch.prim.If %11 -> (!torch.int) {\n"
" torch.prim.If.yield %int0 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %10 : !torch.int\n"
" }\n"
" %13 = torch.aten.lt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %14 = torch.prim.If %13 -> (!torch.int) {\n"
" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %21 = torch.aten.add.int %3, %20 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %21 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %3 : !torch.int\n"
" }\n"
" %15 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %16 = torch.prim.If %15 -> (!torch.int) {\n"
" torch.prim.If.yield %int-1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %14 : !torch.int\n"
" }\n"
" %17 = torch.aten.add.int %16, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %18 = torch.aten.add.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %19 = torch.aten.neg.int %arg4 : !torch.int -> !torch.int\n"
" torch.prim.If.yield %17, %18, %19 : !torch.int, !torch.int, !torch.int\n"
" } else {\n"
" torch.prim.If.yield %1, %3, %arg4 : !torch.int, !torch.int, !torch.int\n"
" }\n"
" %6 = torch.derefine %5#0 : !torch.int to !torch.optional<int>\n"
" %7 = torch.derefine %5#1 : !torch.int to !torch.optional<int>\n"
" %8 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %6, %7, %5#2) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
" return %8 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n" " return %arg1 : !torch.list<int>\n"


@ -17,6 +17,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir; using namespace mlir;
@ -310,7 +311,9 @@ public:
auto selfShape = selfTy.getSizes();
int64_t selfRank = selfShape.size();
- dim = dim < 0 ? dim + selfRank : dim;
dim = toPositiveDim(dim, selfRank);
if (!isValidDim(dim, selfRank))
return failure();
int64_t dimLength = elements.size();
if (selfShape[dim] != dimLength)
return rewriter.notifyMatchFailure(
@ -362,6 +365,11 @@ public:
auto loc = op.getLoc();
ImplicitLocOpBuilder b(loc, rewriter);
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
auto resultTy = cast<BaseTensorType>(op.getType());
if (!selfTy.areAllSizesKnown() || !resultTy.areAllSizesKnown())
return rewriter.notifyMatchFailure(op, "requires static sizes");
SmallVector<OpFoldResult> elements;
if (failed(getListFromTensor(op.getSelf(), elements)))
return failure();
@ -379,39 +387,69 @@ public:
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
return rewriter.notifyMatchFailure(op, "requires a constant step");
- if (step < 0)
- return rewriter.notifyMatchFailure(op, "requires a positive step value");
- auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
auto selfShape = selfTy.getSizes();
auto resultShape = resultTy.getSizes();
int64_t selfRank = selfShape.size();
// Correct for negative indexing:
- dim = dim < 0 ? dim + selfRank : dim;
dim = toPositiveDim(dim, selfRank);
if (!isValidDim(dim, selfRank))
return failure();
- int64_t dimLength = elements.size();
int64_t dimLength = selfShape[dim];
start = start < 0 ? start + dimLength : start;
end = end < 0 ? end + dimLength : end;
end = (end < 0) ? -1 : end;
end = (end < 0 && step > 0) ? 0 : end;
start = start < 0 ? 0 : start;
- end = end < 0 ? 0 : end;
end = end > dimLength ? dimLength : end;
- if (selfShape[dim] != dimLength)
- return rewriter.notifyMatchFailure(
- op, "dim length does not match number of elements");
int64_t frontDimProd = 1, backDimProd = 1;
for (int64_t i = 0; i < selfRank; i++) {
if (i < dim)
frontDimProd *= selfShape[i];
if (i > dim)
backDimProd *= selfShape[i];
}
int64_t fullDimProd = frontDimProd * dimLength * backDimProd;
if (fullDimProd != (int64_t)elements.size())
return rewriter.notifyMatchFailure(op, "unexpected number of elements.");
- for (int64_t i = 0; i < selfRank; ++i) {
- if (i == dim)
- continue;
- if (selfShape[i] != 1)
- return rewriter.notifyMatchFailure(op,
- "expects unary non-dim dimension");
- }
- SmallVector<OpFoldResult> selected;
- for (int i = start; i < end; i += step)
- selected.push_back(elements[i]);
// [d0,d1] i -> (i//d1, i % d1) -> (i//d1) * d1 + (i % d1)
// [d0,d1,d2] i -> (i//d2, i%d2) -> ((i//(d1*d2), (i//d2) % d1, i % d2)
auto isSliceIdx = [&](int64_t i) {
int64_t dimidx = (i / backDimProd) % dimLength;
bool onStep = ((dimidx - start) % step == 0);
bool beforeEnd = (step < 0 && dimidx > end);
beforeEnd = beforeEnd || (step > 0 && dimidx < end);
bool afterBegin = (step < 0 && dimidx <= start);
afterBegin = afterBegin || (step > 0 && dimidx >= start);
return onStep && beforeEnd && afterBegin;
};
auto flipIdx = [&](int64_t i) {
int64_t frontIdx = (i / (backDimProd * dimLength));
int64_t dimIdx = (i / (backDimProd)) % dimLength;
int64_t flipDimIdx = dimLength - 1 - dimIdx;
int64_t backIdx = i % (backDimProd);
return frontIdx * (dimLength * backDimProd) + flipDimIdx * (backDimProd) +
backIdx;
};
SmallVector<OpFoldResult> selected;
for (int64_t i = 0; i < (int64_t)elements.size(); i++) {
if (!isSliceIdx(i))
continue;
int64_t index = (step > 0) ? i : flipIdx(i);
selected.push_back(elements[index]);
}
fullDimProd = (fullDimProd * resultShape[dim]) / selfShape[dim];
if ((int64_t)selected.size() != fullDimProd)
return rewriter.notifyMatchFailure(
op, "Constructed slice values have an incompatable number of "
"elements to match the provided return type.");
SmallVector<Value> values;
if (failed(materializeFolds(b, selected, values)))
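As a plain-Python illustration (not torch-mlir code) of the flattened-index bookkeeping in the slice pattern above, assuming a row-major [d0, d1] shape tensor sliced along dim 0:

```python
# Illustrative sketch of the isSliceIdx / flipIdx logic for a row-major
# [d0, d1] tensor stored as a flat Python list. Names are made up here.
def slice_dim0_flat(flat, d0, d1, start, end, step):
    back = d1  # product of dims after the sliced dim (dim 0 here)
    out = []
    for i in range(len(flat)):
        dimidx = (i // back) % d0
        on_step = (dimidx - start) % step == 0
        in_range = (start <= dimidx < end) if step > 0 else (end < dimidx <= start)
        if not (on_step and in_range):
            continue
        if step > 0:
            out.append(flat[i])
        else:
            # mirror flipIdx: emit the kept rows in reversed dim order
            front = i // (back * d0)
            flip = (front * d0 + (d0 - 1 - dimidx)) * back + i % back
            out.append(flat[flip])
    return out

flat = list(range(8))  # a [4, 2] tensor, row-major
assert slice_dim0_flat(flat, 4, 2, 1, 3, 1) == [2, 3, 4, 5]
assert slice_dim0_flat(flat, 4, 2, 3, -1, -1) == [6, 7, 4, 5, 2, 3, 0, 1]
```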
@ -424,6 +462,114 @@ public:
};
} // namespace
namespace {
class PropagateAtenTransposeIntPattern
: public OpRewritePattern<AtenTransposeIntOp> {
public:
using OpRewritePattern<AtenTransposeIntOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTransposeIntOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
ImplicitLocOpBuilder b(loc, rewriter);
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
auto resultTy = cast<BaseTensorType>(op.getType());
if (!selfTy.areAllSizesKnown() || !resultTy.areAllSizesKnown())
return rewriter.notifyMatchFailure(op, "requires static sizes");
SmallVector<OpFoldResult> elements;
if (failed(getListFromTensor(op.getSelf(), elements)))
return failure();
int64_t dim0, dim1;
if (!matchPattern(op.getDim0(), m_TorchConstantInt(&dim0)))
return failure();
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
return failure();
ArrayRef<int64_t> selfSizes = selfTy.getSizes();
int64_t rank = selfSizes.size();
dim0 = toPositiveDim(dim0, rank);
dim1 = toPositiveDim(dim1, rank);
if (!isValidDim(dim0, rank) || !isValidDim(dim1, rank))
return failure();
if (dim0 == dim1) {
rewriter.replaceOp(op, op.getSelf());
return success();
}
if (dim0 > dim1) {
// swap dim0 and dim1
dim0 = dim0 + dim1;
dim1 = dim0 - dim1;
dim0 -= dim1;
}
// A generic transpose will look like...
// [frontDimsFlat, dim0, midDimsFlat, dim1, backDimsFlat] -> .
// [frontDimsFlat, dim1, midDimsFlat, dim0, backDimsFlat] .
// If any of front, mid, or back don't actually exist (e.g. dim0 = 0, or
// dim1 = dim0 + 1), the reassociation of completely flattened indices will
// remain unaffected by the artificially unsqueezed dims.
// --------
// Setting some notation, let D0,D1,D2,D3,D4 be the respective dim sizes of
// "self". Let D'j be the transpose dim sizes, and Djk = Dj*Dk. Let fl_trans
// and fl_self be 1-D flattened tensors. Then:
// --------
// fl_trans[i] =
// = trans[i/D'1234, i/(D'234) % D'1, i/(D'34) % D'2, i/D'4 % D'3, i % D'4]
// = trans[i/D1234, i/D214 % D3, i/D14 % D2, i/D4 % D1, i % D4]
// = self[i/D1234, i/D4 % D1, i/D14 % D2, i/D214 % D3, i % D4]
// = fl_self[dot.prod(indices, (D1234,D234,D34,D4,1))] .
// --------
// reassoc(i) = (i/(D1234)) * D1234 +
// (i/D4 % D1) * D234 +
// (i/(D14) % D2) * D34 +
// (i/(D214) % D3) * D4 +
// (i % D4) .
SmallVector<int64_t, 5> D(5, 1);
int64_t i = -1;
// D[0] corresponds to flattened front dims
while (++i < dim0)
D[0] *= selfSizes[i];
// D[1] is the earliest transpose dim
D[1] = selfSizes[i];
// D[2] corresponds to flattened middle dims
while (++i < dim1)
D[2] *= selfSizes[i];
// D[3] is the later transpose dim
D[3] = selfSizes[i];
// D[4] corresponds to flattened back dims
while (++i < rank)
D[4] *= selfSizes[i];
int64_t D1234 = D[1] * D[2] * D[3] * D[4];
int64_t fullDP = D[0] * D1234;
if (fullDP != (int64_t)elements.size())
return failure();
auto reassoc = [&](int64_t i) {
return (i / D1234) * D1234 + ((i / D[4]) % D[1]) * D[2] * D[3] * D[4] +
((i / (D[1] * D[4])) % D[2]) * D[3] * D[4] +
((i / (D[2] * D[1] * D[4])) % D[3]) * D[4] + (i % D[4]);
};
SmallVector<OpFoldResult> transposedFolds;
transposedFolds.reserve(fullDP);
for (int64_t i = 0; i < fullDP; i++)
transposedFolds.push_back(elements[reassoc(i)]);
SmallVector<Value> transposedVals;
if (failed(materializeFolds(b, transposedFolds, transposedVals)))
return failure();
Value result = constructAtenTensorOpFromList(b, resultTy, transposedVals);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
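A quick Python check (illustration only; the example shape is made up) of the reassociation index formula documented in the comment block of `PropagateAtenTransposeIntPattern`:

```python
# reassoc maps a flat index of the transposed tensor back to a flat index
# of self, where D[0]/D[2]/D[4] are the flattened front/middle/back dims
# and D[1]/D[3] are the two transposed dims.
def reassoc(i, D):
    D1234 = D[1] * D[2] * D[3] * D[4]
    return ((i // D1234) * D1234
            + ((i // D[4]) % D[1]) * D[2] * D[3] * D[4]
            + ((i // (D[1] * D[4])) % D[2]) * D[3] * D[4]
            + ((i // (D[2] * D[1] * D[4])) % D[3]) * D[4]
            + (i % D[4]))

# self shape [2, 3], transpose(0, 1) -> [3, 2]; front/mid/back dims are size 1.
D = [1, 2, 1, 3, 1]
flat_self = list(range(6))                      # [[0, 1, 2], [3, 4, 5]]
flat_trans = [flat_self[reassoc(i, D)] for i in range(6)]
assert flat_trans == [0, 3, 1, 4, 2, 5]         # [[0, 3], [1, 4], [2, 5]]
```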
namespace {
class PropagateAtenWhereSelfPattern : public OpRewritePattern<AtenWhereSelfOp> {
public:
@ -600,6 +746,27 @@ public:
};
} // namespace
namespace {
template <typename AtenViewLikeOp>
class PropagateAtenViewLikePattern : public OpRewritePattern<AtenViewLikeOp> {
public:
using OpRewritePattern<AtenViewLikeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenViewLikeOp op,
PatternRewriter &rewriter) const override {
SmallVector<OpFoldResult> selfFolds;
if (failed(getListFromTensor(op.getSelf(), selfFolds)))
return failure();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
SmallVector<Value> selfVals;
if (failed(materializeFolds(b, selfFolds, selfVals)))
return failure();
Value result = constructAtenTensorOpFromList(b, op.getType(), selfVals);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
namespace {
template <typename OpTy> struct ArithmeticHelper {
@ -1065,6 +1232,34 @@ bool isAnchorOp(Operation *op) {
isPrimListOfInts(op);
}
// The argument to this function, op, is the use of some source op, srcOp. If
// this function returns true, we want to invalidate srcOp as a target for shape
// scalarization.
bool isInvalidValidViewConsumer(Operation *op,
SetVector<Operation *> &workList) {
// if the consumer isn't a view op, don't invalidate it
auto view = dyn_cast_or_null<AtenViewOp>(op);
if (!view)
return false;
auto resultTy = dyn_cast<ValueTensorType>(view.getType());
if (!resultTy || !resultTy.hasDtype())
return true;
// if the view op doesn't return integer types, then srcOp is not a shape
// tensor. note: prim lists will always get added before reaching this
// function call.
if (!isa<mlir::IntegerType>(resultTy.getDtype()))
return true;
// check uses of the view op.
// If the view op has a use in our worklist, then it needs to be scalarized.
for (OpOperand &use : op->getUses()) {
Operation *userOp = use.getOwner();
if (workList.contains(userOp))
return false;
}
// invalidate, since the view op was added as a one-off for canonicalization.
return true;
}
void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
FoldAtenSqueezePattern<AtenSqueezeDimOp>,
@ -1078,6 +1273,11 @@ void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
}
void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
patterns.add<PropagateAtenViewLikePattern<AtenViewOp>>(patterns.getContext(),
/*benefit=*/10);
patterns.insert<PropagateAtenViewLikePattern<AtenFlattenUsingIntsOp>,
PropagateAtenViewLikePattern<AtenUnflattenIntOp>>(
patterns.getContext());
// A note on division: onnx.Div from int, int -> int types rounds towards
// zero. The torch DivTensorOp actually doesn't allow returning an int dtype,
// but this was artificially plummbed through. Unfortunately, there is no
@ -1088,6 +1288,7 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
PropagateAtenTransposeIntPattern,
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
@ -1105,9 +1306,6 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
RemoveUnusedPattern<Torch::AtenTensorOp>,
- RemoveUnusedPattern<Torch::ConstantBoolOp>,
- RemoveUnusedPattern<Torch::ConstantIntOp>,
- RemoveUnusedPattern<Torch::ConstantNoneOp>,
RemoveUnusedPattern<Torch::PrimListConstructOp>>(
patterns.getContext());
}
@ -1168,12 +1366,12 @@ public:
// shapeCalculationOps. It's consumer (%1) is indeed a shape
// calculation op, but the size.int op is an elementary unit of shape
// computation. No futher gathering of producers is necessary to
- // reduce this. Similarly, don't add the `self` of a view op.
// reduce this. Similarly, don't always add the `self` of a view op.
for (OpOperand &use : op->getUses()) {
Operation *userOp = use.getOwner();
if (shapeCalculationOps.contains(userOp) &&
!isSourceOpForShapeScalarization(userOp) &&
- !isa<AtenViewOp>(userOp)) {
!isInvalidValidViewConsumer(userOp, shapeCalculationOps)) {
shapeCalculationOps.insert(op);
return;
}


@ -1903,7 +1903,33 @@ def aten_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int =
return upstream_shape_functions.unary(v), upstream_shape_functions.unary(g)
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
- return upstream_shape_functions.slice(self, dim, start, end, step)
start_val = start if start is not None else 0
end_val = end if end is not None else upstream_shape_functions.max_int()
if (step < 0):
# Convert to equivalent positive-step parameters, which will require swapping start and end.
# If the parameters are in the normal range (0 <= start < d and -1 <= end <= start), then
# swapped_end = start + 1 and swapped_begin = end + 1.
# The shift of inclusion can cause issues if these parameters are not already resolved on the left.
# e.g. start = -1, end = -3 . So valid start is actually d-1, and valid end is d-3. Therefore, we
# should have swapped_end = d, but adding 1 to start before making it valid would result in an
# incorrect, but "valid", swapped_end = 0 for forward slicing.
# Additionally, if adding d doesn't make these values positive, but adding twice would, we need
# to clamp after resolving, otherwise the upstream function will try to resolve a second time.
if start_val < 0:
start_val += self[dim]
if start_val < 0:
start_val = 0
if end_val < 0:
end_val += self[dim]
if end_val < 0:
end_val = -1
tmp = end_val + 1
end_val = start_val + 1
start_val = tmp
step = -step
return upstream_shape_functions.slice(self,dim,start_val,end_val,step)
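To make the swap above concrete, a small sketch (the helper name `to_positive_step` is hypothetical, not part of the change) of the conversion for the pad-export parameters:

```python
def to_positive_step(d, start, end, step):
    # Mirrors the negative-step rewrite in the shape function above
    # (illustration only; not the actual shape-library code).
    assert step < 0
    start = start + d if start < 0 else start
    start = max(start, 0)
    end = end + d if end < 0 else end
    end = max(end, -1)
    return end + 1, start + 1, -step  # swapped (start, end, step)

# Pad-export example: d=4, start=-1, end=-9223372036854775807, step=-1
# becomes the equivalent forward slice [0:4:1], keeping all 4 entries.
assert to_positive_step(4, -1, -9223372036854775807, -1) == (0, 4, 1)
```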
def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
return size


@ -374,3 +374,262 @@ func.func @squeeze_dim_full_fold(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.li
%59 = torch.prim.ListConstruct %58 : (!torch.int) -> !torch.list<int>
return %59 : !torch.list<int>
}
// -----
// CHECK-LABEL: @pytorch_dynamic_pad_export_view$prop(
func.func @pytorch_dynamic_pad_export_view$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[4,2],si64> {
// CHECK: %[[I0:.*]] = torch.constant.int 0
// CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I2:.*]] = torch.constant.int 2
// CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I3:.*]] = torch.constant.int 3
// CHECK: %[[x2:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I144:.*]] = torch.constant.int 144
// CHECK: %[[I0_0:.*]] = torch.constant.int 0
// CHECK: %[[I0_1:.*]] = torch.constant.int 0
// CHECK: %[[I0_2:.*]] = torch.constant.int 0
// CHECK: %[[I0_3:.*]] = torch.constant.int 0
// CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[x0]], %[[I144]], %[[x1]], %[[x2]], %[[I0_0]], %[[I0_1]], %[[I0_2]], %[[I0_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[none:.*]] = torch.constant.none
// CHECK: %[[false:.*]] = torch.constant.bool false
// CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4,2],si64>
// CHECK: return %[[x4]] : !torch.vtensor<[4,2],si64>
%0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int-9223372036854775807 = torch.constant.int -9223372036854775807
%int-1 = torch.constant.int -1
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64>
%4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
%5 = torch.aten.cat %4, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
%6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
%8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
%9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
%10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
%12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
%14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int
%16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list<int>
return %7 : !torch.vtensor<[4,2],si64>
}
// -----
// CHECK-LABEL: @pytorch_dynamic_pad_export_slice$prop(
func.func @pytorch_dynamic_pad_export_slice$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[4,2],si64> {
// CHECK: %[[I0:.*]] = torch.constant.int 0
// CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I2:.*]] = torch.constant.int 2
// CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I3:.*]] = torch.constant.int 3
// CHECK: %[[x2:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I0_0:.*]] = torch.constant.int 0
// CHECK: %[[I0_1:.*]] = torch.constant.int 0
// CHECK: %[[I0_2:.*]] = torch.constant.int 0
// CHECK: %[[I0_3:.*]] = torch.constant.int 0
// CHECK: %[[I144:.*]] = torch.constant.int 144
// CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[I0_2]], %[[I0_3]], %[[x1]], %[[x2]], %[[x0]], %[[I144]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[none:.*]] = torch.constant.none
// CHECK: %[[false:.*]] = torch.constant.bool false
// CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4,2],si64>
// CHECK: return %[[x4]] : !torch.vtensor<[4,2],si64>
%0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int-9223372036854775807 = torch.constant.int -9223372036854775807
%int-1 = torch.constant.int -1
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64>
%4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
%5 = torch.aten.cat %4, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
%6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
%8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
%9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
%10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
%12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
%14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int
%16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list<int>
return %8 : !torch.vtensor<[4,2],si64>
}
// -----
// CHECK-LABEL: @pytorch_dynamic_pad_export_transpose$prop(
func.func @pytorch_dynamic_pad_export_transpose$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[2,4],si64> {
// CHECK: %[[I0:.*]] = torch.constant.int 0
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I2:.*]] = torch.constant.int 2
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I3:.*]] = torch.constant.int 3
// CHECK: %[[DIM3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I0_0:.*]] = torch.constant.int 0
// CHECK: %[[I0_1:.*]] = torch.constant.int 0
// CHECK: %[[I0_2:.*]] = torch.constant.int 0
// CHECK: %[[I0_3:.*]] = torch.constant.int 0
// CHECK: %[[DIM1:.*]] = torch.constant.int 144
// CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[DIM2]], %[[DIM0]], %[[I0_2]], %[[I0_3]], %[[DIM3]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[none:.*]] = torch.constant.none
// CHECK: %[[false:.*]] = torch.constant.bool false
// CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,4],si64>
// CHECK: return %[[x4]] : !torch.vtensor<[2,4],si64>
%0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int-9223372036854775807 = torch.constant.int -9223372036854775807
%int-1 = torch.constant.int -1
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64>
%4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
%5 = torch.aten.cat %4, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
%6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
%8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
%9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
%10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
%12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
%14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int
%16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list<int>
return %9 : !torch.vtensor<[2,4],si64>
}
// -----
// CHECK-LABEL: @pytorch_dynamic_pad_export_full(
func.func @pytorch_dynamic_pad_export_full(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.list<int> {
// CHECK: %[[I2:.*]] = torch.constant.int 2
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I0:.*]] = torch.constant.int 0
// CHECK: %[[x1:.*]] = torch.prim.ListConstruct %[[DIM2]], %[[I0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: return %[[x1]] : !torch.list<int>
%0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
%1 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int-9223372036854775807 = torch.constant.int -9223372036854775807
%int-1 = torch.constant.int -1
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64>
%4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
%5 = torch.aten.cat %4, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
%6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
%8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
%9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
%10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
%12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
%14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int
%16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list<int>
return %16 : !torch.list<int>
}
// -----
// CHECK-LABEL: @transpose$prop_3d_0_1
func.func @transpose$prop_3d_0_1(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[2,2,2],si64> {
// CHECK: %[[I0:.*]] = torch.constant.int 0
// CHECK: %[[SIZE0_0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I1:.*]] = torch.constant.int 1
// CHECK: %[[SIZE0_1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I2:.*]] = torch.constant.int 2
// CHECK: %[[SIZE0_2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I3:.*]] = torch.constant.int 3
// CHECK: %[[SIZE0_3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I0_0:.*]] = torch.constant.int 0
// CHECK: %[[SIZE1_0:.*]] = torch.aten.size.int %arg1, %[[I0_0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I1_1:.*]] = torch.constant.int 1
// CHECK: %[[SIZE1_1:.*]] = torch.aten.size.int %arg1, %[[I1_1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I2_2:.*]] = torch.constant.int 2
// CHECK: %[[SIZE1_2:.*]] = torch.aten.size.int %arg1, %[[I2_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I3_3:.*]] = torch.constant.int 3
// CHECK: %[[SIZE1_3:.*]] = torch.aten.size.int %arg1, %[[I3_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SIZE0_0]], %[[SIZE0_1]], %[[SIZE1_0]], %[[SIZE1_1]], %[[SIZE0_2]], %[[SIZE0_3]], %[[SIZE1_2]], %[[SIZE1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,2,2],si64>
// CHECK: return %[[TENSOR]] : !torch.vtensor<[2,2,2],si64>
%0 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int-1 = torch.constant.int -1
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%1 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
%2 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
%3 = torch.prim.ListConstruct %1, %2 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
%4 = torch.aten.cat %3, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
%5 = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%6 = torch.aten.view %4, %5 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[2,2,2],si64>
%7 = torch.aten.transpose.int %6, %int0, %int1 : !torch.vtensor<[2,2,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,2,2],si64>
%8 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%9 = torch.aten.view %7, %8 : !torch.vtensor<[2,2,2],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
%10 = torch.aten.index_select %9, %int0, %0 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int
%12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list<int>
return %7 : !torch.vtensor<[2,2,2],si64>
}
// -----
// CHECK-LABEL: @transpose$prop_3d_m1_0
func.func @transpose$prop_3d_m1_0(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[2,2,2],si64> {
// CHECK: %[[I0:.*]] = torch.constant.int 0
// CHECK: %[[SIZE0_0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I1:.*]] = torch.constant.int 1
// CHECK: %[[SIZE0_1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I2:.*]] = torch.constant.int 2
// CHECK: %[[SIZE0_2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I3:.*]] = torch.constant.int 3
// CHECK: %[[SIZE0_3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I0_0:.*]] = torch.constant.int 0
// CHECK: %[[SIZE1_0:.*]] = torch.aten.size.int %arg1, %[[I0_0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I1_1:.*]] = torch.constant.int 1
// CHECK: %[[SIZE1_1:.*]] = torch.aten.size.int %arg1, %[[I1_1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I2_2:.*]] = torch.constant.int 2
// CHECK: %[[SIZE1_2:.*]] = torch.aten.size.int %arg1, %[[I2_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I3_3:.*]] = torch.constant.int 3
// CHECK: %[[SIZE1_3:.*]] = torch.aten.size.int %arg1, %[[I3_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SIZE0_0]], %[[SIZE1_0]], %[[SIZE0_2]], %[[SIZE1_2]], %[[SIZE0_1]], %[[SIZE1_1]], %[[SIZE0_3]], %[[SIZE1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,2,2],si64>
// CHECK: return %[[TENSOR]] : !torch.vtensor<[2,2,2],si64>
%0 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int-1 = torch.constant.int -1
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%1 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
%2 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
%3 = torch.prim.ListConstruct %1, %2 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
%4 = torch.aten.cat %3, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
%5 = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%6 = torch.aten.view %4, %5 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[2,2,2],si64>
%7 = torch.aten.transpose.int %6, %int-1, %int0 : !torch.vtensor<[2,2,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,2,2],si64>
%8 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
%9 = torch.aten.view %7, %8 : !torch.vtensor<[2,2,2],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
%10 = torch.aten.index_select %9, %int0, %0 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
%11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int
%12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list<int>
return %7 : !torch.vtensor<[2,2,2],si64>
}