mirror of https://github.com/llvm/torch-mlir
add scalarization patterns to support dynamic pytorch pad exports (#3838)
1. Adds case handling for `aten.slice.tensor` shape inference with negative strides. This is not technically allowed by native pytorch, but it is useful for ONNX ingest. We were getting some incorrect shapes for these negative strided slice ops. 2. Adds scalarization support for ops seen in pytorch pad exports to ONNX. These are typically `aten.view` `aten.transpose.int` and `aten.slice.Tensor` with negative strides (and rank 2). 3. Allows view op `self` to be added to the worklist conditionally, based on whether the view op actually occurs as a middle point in a shape computation.pull/3845/head
parent
39d69db5ca
commit
738d45d3bb
|
@ -10131,8 +10131,66 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
|
" %int-1 = torch.constant.int -1\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" %none = torch.constant.none\n"
|
||||||
|
" %int0 = torch.constant.int 0\n"
|
||||||
|
" %int1 = torch.constant.int 1\n"
|
||||||
|
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||||
|
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||||
|
" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
|
||||||
|
" torch.prim.If.yield %9 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %int0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" %2 = torch.aten.__isnot__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||||
|
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
|
||||||
|
" %9 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
|
||||||
|
" torch.prim.If.yield %9 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %9 = func.call @__torch__.torch.jit._shape_functions.max_int() : () -> !torch.int\n"
|
||||||
|
" torch.prim.If.yield %9 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" %4 = torch.aten.lt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %5:3 = torch.prim.If %4 -> (!torch.int, !torch.int, !torch.int) {\n"
|
||||||
|
" %9 = torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %10 = torch.prim.If %9 -> (!torch.int) {\n"
|
||||||
|
" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %21 = torch.aten.add.int %1, %20 : !torch.int, !torch.int -> !torch.int\n"
|
||||||
|
" torch.prim.If.yield %21 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" %11 = torch.aten.lt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %12 = torch.prim.If %11 -> (!torch.int) {\n"
|
||||||
|
" torch.prim.If.yield %int0 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %10 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" %13 = torch.aten.lt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %14 = torch.prim.If %13 -> (!torch.int) {\n"
|
||||||
|
" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %21 = torch.aten.add.int %3, %20 : !torch.int, !torch.int -> !torch.int\n"
|
||||||
|
" torch.prim.If.yield %21 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %3 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" %15 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %16 = torch.prim.If %15 -> (!torch.int) {\n"
|
||||||
|
" torch.prim.If.yield %int-1 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %14 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" %17 = torch.aten.add.int %16, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||||
|
" %18 = torch.aten.add.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||||
|
" %19 = torch.aten.neg.int %arg4 : !torch.int -> !torch.int\n"
|
||||||
|
" torch.prim.If.yield %17, %18, %19 : !torch.int, !torch.int, !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %1, %3, %arg4 : !torch.int, !torch.int, !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" %6 = torch.derefine %5#0 : !torch.int to !torch.optional<int>\n"
|
||||||
|
" %7 = torch.derefine %5#1 : !torch.int to !torch.optional<int>\n"
|
||||||
|
" %8 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %6, %7, %5#2) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" return %8 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||||
" return %arg1 : !torch.list<int>\n"
|
" return %arg1 : !torch.list<int>\n"
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
@ -310,7 +311,9 @@ public:
|
||||||
|
|
||||||
auto selfShape = selfTy.getSizes();
|
auto selfShape = selfTy.getSizes();
|
||||||
int64_t selfRank = selfShape.size();
|
int64_t selfRank = selfShape.size();
|
||||||
dim = dim < 0 ? dim + selfRank : dim;
|
dim = toPositiveDim(dim, selfRank);
|
||||||
|
if (!isValidDim(dim, selfRank))
|
||||||
|
return failure();
|
||||||
int64_t dimLength = elements.size();
|
int64_t dimLength = elements.size();
|
||||||
if (selfShape[dim] != dimLength)
|
if (selfShape[dim] != dimLength)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -362,6 +365,11 @@ public:
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
ImplicitLocOpBuilder b(loc, rewriter);
|
ImplicitLocOpBuilder b(loc, rewriter);
|
||||||
|
|
||||||
|
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
|
||||||
|
auto resultTy = cast<BaseTensorType>(op.getType());
|
||||||
|
if (!selfTy.areAllSizesKnown() || !resultTy.areAllSizesKnown())
|
||||||
|
return rewriter.notifyMatchFailure(op, "requires static sizes");
|
||||||
|
|
||||||
SmallVector<OpFoldResult> elements;
|
SmallVector<OpFoldResult> elements;
|
||||||
if (failed(getListFromTensor(op.getSelf(), elements)))
|
if (failed(getListFromTensor(op.getSelf(), elements)))
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -379,39 +387,69 @@ public:
|
||||||
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
|
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
|
||||||
return rewriter.notifyMatchFailure(op, "requires a constant step");
|
return rewriter.notifyMatchFailure(op, "requires a constant step");
|
||||||
|
|
||||||
if (step < 0)
|
|
||||||
return rewriter.notifyMatchFailure(op, "requires a positive step value");
|
|
||||||
|
|
||||||
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
|
|
||||||
auto selfShape = selfTy.getSizes();
|
auto selfShape = selfTy.getSizes();
|
||||||
|
auto resultShape = resultTy.getSizes();
|
||||||
int64_t selfRank = selfShape.size();
|
int64_t selfRank = selfShape.size();
|
||||||
|
|
||||||
// Correct for negative indexing:
|
// Correct for negative indexing:
|
||||||
dim = dim < 0 ? dim + selfRank : dim;
|
dim = toPositiveDim(dim, selfRank);
|
||||||
|
if (!isValidDim(dim, selfRank))
|
||||||
|
return failure();
|
||||||
|
|
||||||
int64_t dimLength = elements.size();
|
int64_t dimLength = selfShape[dim];
|
||||||
start = start < 0 ? start + dimLength : start;
|
start = start < 0 ? start + dimLength : start;
|
||||||
end = end < 0 ? end + dimLength : end;
|
end = end < 0 ? end + dimLength : end;
|
||||||
|
end = (end < 0) ? -1 : end;
|
||||||
|
end = (end < 0 && step > 0) ? 0 : end;
|
||||||
|
|
||||||
start = start < 0 ? 0 : start;
|
start = start < 0 ? 0 : start;
|
||||||
end = end < 0 ? 0 : end;
|
|
||||||
end = end > dimLength ? dimLength : end;
|
end = end > dimLength ? dimLength : end;
|
||||||
|
|
||||||
if (selfShape[dim] != dimLength)
|
int64_t frontDimProd = 1, backDimProd = 1;
|
||||||
return rewriter.notifyMatchFailure(
|
for (int64_t i = 0; i < selfRank; i++) {
|
||||||
op, "dim length does not match number of elements");
|
if (i < dim)
|
||||||
|
frontDimProd *= selfShape[i];
|
||||||
|
if (i > dim)
|
||||||
|
backDimProd *= selfShape[i];
|
||||||
|
}
|
||||||
|
int64_t fullDimProd = frontDimProd * dimLength * backDimProd;
|
||||||
|
if (fullDimProd != (int64_t)elements.size())
|
||||||
|
return rewriter.notifyMatchFailure(op, "unexpected number of elements.");
|
||||||
|
|
||||||
for (int64_t i = 0; i < selfRank; ++i) {
|
// [d0,d1] i -> (i//d1, i % d1) -> (i//d1) * d1 + (i % d1)
|
||||||
if (i == dim)
|
// [d0,d1,d2] i -> (i//d2, i%d2) -> ((i//(d1*d2), (i//d2) % d1, i % d2)
|
||||||
|
|
||||||
|
auto isSliceIdx = [&](int64_t i) {
|
||||||
|
int64_t dimidx = (i / backDimProd) % dimLength;
|
||||||
|
bool onStep = ((dimidx - start) % step == 0);
|
||||||
|
bool beforeEnd = (step < 0 && dimidx > end);
|
||||||
|
beforeEnd = beforeEnd || (step > 0 && dimidx < end);
|
||||||
|
bool afterBegin = (step < 0 && dimidx <= start);
|
||||||
|
afterBegin = afterBegin || (step > 0 && dimidx >= start);
|
||||||
|
return onStep && beforeEnd && afterBegin;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto flipIdx = [&](int64_t i) {
|
||||||
|
int64_t frontIdx = (i / (backDimProd * dimLength));
|
||||||
|
int64_t dimIdx = (i / (backDimProd)) % dimLength;
|
||||||
|
int64_t flipDimIdx = dimLength - 1 - dimIdx;
|
||||||
|
int64_t backIdx = i % (backDimProd);
|
||||||
|
return frontIdx * (dimLength * backDimProd) + flipDimIdx * (backDimProd) +
|
||||||
|
backIdx;
|
||||||
|
};
|
||||||
|
SmallVector<OpFoldResult> selected;
|
||||||
|
for (int64_t i = 0; i < (int64_t)elements.size(); i++) {
|
||||||
|
if (!isSliceIdx(i))
|
||||||
continue;
|
continue;
|
||||||
if (selfShape[i] != 1)
|
int64_t index = (step > 0) ? i : flipIdx(i);
|
||||||
return rewriter.notifyMatchFailure(op,
|
selected.push_back(elements[index]);
|
||||||
"expects unary non-dim dimension");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<OpFoldResult> selected;
|
fullDimProd = (fullDimProd * resultShape[dim]) / selfShape[dim];
|
||||||
for (int i = start; i < end; i += step)
|
if ((int64_t)selected.size() != fullDimProd)
|
||||||
selected.push_back(elements[i]);
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Constructed slice values have an incompatable number of "
|
||||||
|
"elements to match the provided return type.");
|
||||||
|
|
||||||
SmallVector<Value> values;
|
SmallVector<Value> values;
|
||||||
if (failed(materializeFolds(b, selected, values)))
|
if (failed(materializeFolds(b, selected, values)))
|
||||||
|
@ -424,6 +462,114 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class PropagateAtenTransposeIntPattern
|
||||||
|
: public OpRewritePattern<AtenTransposeIntOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<AtenTransposeIntOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenTransposeIntOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
ImplicitLocOpBuilder b(loc, rewriter);
|
||||||
|
|
||||||
|
auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
|
||||||
|
auto resultTy = cast<BaseTensorType>(op.getType());
|
||||||
|
if (!selfTy.areAllSizesKnown() || !resultTy.areAllSizesKnown())
|
||||||
|
return rewriter.notifyMatchFailure(op, "requires static sizes");
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> elements;
|
||||||
|
if (failed(getListFromTensor(op.getSelf(), elements)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t dim0, dim1;
|
||||||
|
if (!matchPattern(op.getDim0(), m_TorchConstantInt(&dim0)))
|
||||||
|
return failure();
|
||||||
|
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> selfSizes = selfTy.getSizes();
|
||||||
|
int64_t rank = selfSizes.size();
|
||||||
|
|
||||||
|
dim0 = toPositiveDim(dim0, rank);
|
||||||
|
dim1 = toPositiveDim(dim1, rank);
|
||||||
|
if (!isValidDim(dim0, rank) || !isValidDim(dim0, rank))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (dim0 == dim1) {
|
||||||
|
rewriter.replaceOp(op, op.getSelf());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dim0 > dim1) {
|
||||||
|
// swap dim0 and dim1
|
||||||
|
dim0 = dim0 + dim1;
|
||||||
|
dim1 = dim0 - dim1;
|
||||||
|
dim0 -= dim1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// A generic transpose will look like...
|
||||||
|
// [frontDimsFlat, dim0, midDimsFlat, dim1, backDimsFlat] -> .
|
||||||
|
// [frontDimsFlat, dim1, midDimsFlat, dim0, backDimsFlat] .
|
||||||
|
// If any of front, mid, or back don't actually exist (e.g. dim0 = 0, or
|
||||||
|
// dim1 = dim0 + 1), the reassociation of completely flattened indices will
|
||||||
|
// remain unaffected by the artificially unsqueezed dims.
|
||||||
|
// --------
|
||||||
|
// Setting some notation, let D0,D1,D2,D3,D4 be the respective dim sizes of
|
||||||
|
// "self". Let D'j be the transpose dim sizes, and Djk = Dj*Dk. Let fl_trans
|
||||||
|
// and fl_self be 1-D flattened tensors. Then:
|
||||||
|
// --------
|
||||||
|
// fl_trans[i] =
|
||||||
|
// = trans[i/D'1234, i/(D'234) % D'1, i/(D'34) % D'2, i/D'4 % D'3, i % D'4]
|
||||||
|
// = trans[i/D1234, i/D214 % D3, i/D14 % D2, i/D4 % D1, i % D4]
|
||||||
|
// = self[i/D1234, i/D4 % D1, i/D14 % D2, i/D214 % D3, i % D4]
|
||||||
|
// = fl_self[dot.prod(indices, (D1234,D234,D34,D4,1))] .
|
||||||
|
// --------
|
||||||
|
// reassoc(i) = (i/(D1234)) * D1234 +
|
||||||
|
// (i/D4 % D1) * D234 +
|
||||||
|
// (i/(D14) % D2) * D34 +
|
||||||
|
// (i/(D214) % D3) * D4 +
|
||||||
|
// (i % D4) .
|
||||||
|
|
||||||
|
SmallVector<int64_t, 5> D(5, 1);
|
||||||
|
int64_t i = -1;
|
||||||
|
// D[0] corresponds to flattened front dims
|
||||||
|
while (++i < dim0)
|
||||||
|
D[0] *= selfSizes[i];
|
||||||
|
// D[1] is the earliest transpose dim
|
||||||
|
D[1] = selfSizes[i];
|
||||||
|
// D[2] corresponds to flattened middle dims
|
||||||
|
while (++i < dim1)
|
||||||
|
D[2] *= selfSizes[i];
|
||||||
|
// D[3] is the later transpose dim
|
||||||
|
D[3] = selfSizes[i];
|
||||||
|
// D[4] corresponds to flattened back dims
|
||||||
|
while (++i < rank)
|
||||||
|
D[4] *= selfSizes[i];
|
||||||
|
|
||||||
|
int64_t D1234 = D[1] * D[2] * D[3] * D[4];
|
||||||
|
int64_t fullDP = D[0] * D1234;
|
||||||
|
if (fullDP != (int64_t)elements.size())
|
||||||
|
return failure();
|
||||||
|
auto reassoc = [&](int64_t i) {
|
||||||
|
return (i / D1234) * D1234 + ((i / D[4]) % D[1]) * D[2] * D[3] * D[4] +
|
||||||
|
((i / (D[1] * D[4])) % D[2]) * D[3] * D[4] +
|
||||||
|
((i / (D[2] * D[1] * D[4])) % D[3]) * D[4] + (i % D[4]);
|
||||||
|
};
|
||||||
|
SmallVector<OpFoldResult> transposedFolds;
|
||||||
|
transposedFolds.reserve(fullDP);
|
||||||
|
for (int64_t i = 0; i < fullDP; i++)
|
||||||
|
transposedFolds.push_back(elements[reassoc(i)]);
|
||||||
|
|
||||||
|
SmallVector<Value> transposedVals;
|
||||||
|
if (failed(materializeFolds(b, transposedFolds, transposedVals)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value result = constructAtenTensorOpFromList(b, resultTy, transposedVals);
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
namespace {
|
namespace {
|
||||||
class PropagateAtenWhereSelfPattern : public OpRewritePattern<AtenWhereSelfOp> {
|
class PropagateAtenWhereSelfPattern : public OpRewritePattern<AtenWhereSelfOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -600,6 +746,27 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename AtenViewLikeOp>
|
||||||
|
class PropagateAtenViewLikePattern : public OpRewritePattern<AtenViewLikeOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<AtenViewLikeOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenViewLikeOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
SmallVector<OpFoldResult> selfFolds;
|
||||||
|
if (failed(getListFromTensor(op.getSelf(), selfFolds)))
|
||||||
|
return failure();
|
||||||
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||||
|
SmallVector<Value> selfVals;
|
||||||
|
if (failed(materializeFolds(b, selfFolds, selfVals)))
|
||||||
|
return failure();
|
||||||
|
Value result = constructAtenTensorOpFromList(b, op.getType(), selfVals);
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename OpTy> struct ArithmeticHelper {
|
template <typename OpTy> struct ArithmeticHelper {
|
||||||
|
@ -1065,6 +1232,34 @@ bool isAnchorOp(Operation *op) {
|
||||||
isPrimListOfInts(op);
|
isPrimListOfInts(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The argument to this function, op, is the use of some source op, srcOp. If
|
||||||
|
// this function returns true, we want to invalidate srcOp as a target for shape
|
||||||
|
// scalarization.
|
||||||
|
bool isInvalidValidViewConsumer(Operation *op,
|
||||||
|
SetVector<Operation *> &workList) {
|
||||||
|
// if the consumer isn't a view op, don't invalidate it
|
||||||
|
auto view = dyn_cast_or_null<AtenViewOp>(op);
|
||||||
|
if (!view)
|
||||||
|
return false;
|
||||||
|
auto resultTy = dyn_cast<ValueTensorType>(view.getType());
|
||||||
|
if (!resultTy || !resultTy.hasDtype())
|
||||||
|
return true;
|
||||||
|
// if the view op doesn't return integer types, then srcOp is not a shape
|
||||||
|
// tensor. note: prim lists will always get added before reaching this
|
||||||
|
// function call.
|
||||||
|
if (!isa<mlir::IntegerType>(resultTy.getDtype()))
|
||||||
|
return true;
|
||||||
|
// check uses of the view op.
|
||||||
|
// If the view op has a use in our worklist, then it needs to be scalarized.
|
||||||
|
for (OpOperand &use : op->getUses()) {
|
||||||
|
Operation *userOp = use.getOwner();
|
||||||
|
if (workList.contains(userOp))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// invalidate, since the view op was added as a one-off for canonicalization.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
|
void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
|
||||||
patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
|
patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
|
||||||
FoldAtenSqueezePattern<AtenSqueezeDimOp>,
|
FoldAtenSqueezePattern<AtenSqueezeDimOp>,
|
||||||
|
@ -1078,6 +1273,11 @@ void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
|
void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
|
||||||
|
patterns.add<PropagateAtenViewLikePattern<AtenViewOp>>(patterns.getContext(),
|
||||||
|
/*benefit=*/10);
|
||||||
|
patterns.insert<PropagateAtenViewLikePattern<AtenFlattenUsingIntsOp>,
|
||||||
|
PropagateAtenViewLikePattern<AtenUnflattenIntOp>>(
|
||||||
|
patterns.getContext());
|
||||||
// A note on division: onnx.Div from int, int -> int types rounds towards
|
// A note on division: onnx.Div from int, int -> int types rounds towards
|
||||||
// zero. The torch DivTensorOp actually doesn't allow returning an int dtype,
|
// zero. The torch DivTensorOp actually doesn't allow returning an int dtype,
|
||||||
// but this was artificially plummbed through. Unfortunately, there is no
|
// but this was artificially plummbed through. Unfortunately, there is no
|
||||||
|
@ -1088,6 +1288,7 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
|
||||||
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
|
PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
|
||||||
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
|
PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
|
||||||
PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
|
PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
|
||||||
|
PropagateAtenTransposeIntPattern,
|
||||||
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
|
PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
|
||||||
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
|
PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
|
||||||
PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
|
PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
|
||||||
|
@ -1105,9 +1306,6 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
|
||||||
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
|
RemoveUnusedPattern<Torch::AtenSizeIntOp>,
|
||||||
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
|
RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
|
||||||
RemoveUnusedPattern<Torch::AtenTensorOp>,
|
RemoveUnusedPattern<Torch::AtenTensorOp>,
|
||||||
RemoveUnusedPattern<Torch::ConstantBoolOp>,
|
|
||||||
RemoveUnusedPattern<Torch::ConstantIntOp>,
|
|
||||||
RemoveUnusedPattern<Torch::ConstantNoneOp>,
|
|
||||||
RemoveUnusedPattern<Torch::PrimListConstructOp>>(
|
RemoveUnusedPattern<Torch::PrimListConstructOp>>(
|
||||||
patterns.getContext());
|
patterns.getContext());
|
||||||
}
|
}
|
||||||
|
@ -1168,12 +1366,12 @@ public:
|
||||||
// shapeCalculationOps. It's consumer (%1) is indeed a shape
|
// shapeCalculationOps. It's consumer (%1) is indeed a shape
|
||||||
// calculation op, but the size.int op is an elementary unit of shape
|
// calculation op, but the size.int op is an elementary unit of shape
|
||||||
// computation. No futher gathering of producers is necessary to
|
// computation. No futher gathering of producers is necessary to
|
||||||
// reduce this. Similarly, don't add the `self` of a view op.
|
// reduce this. Similarly, don't always add the `self` of a view op.
|
||||||
for (OpOperand &use : op->getUses()) {
|
for (OpOperand &use : op->getUses()) {
|
||||||
Operation *userOp = use.getOwner();
|
Operation *userOp = use.getOwner();
|
||||||
if (shapeCalculationOps.contains(userOp) &&
|
if (shapeCalculationOps.contains(userOp) &&
|
||||||
!isSourceOpForShapeScalarization(userOp) &&
|
!isSourceOpForShapeScalarization(userOp) &&
|
||||||
!isa<AtenViewOp>(userOp)) {
|
!isInvalidValidViewConsumer(userOp, shapeCalculationOps)) {
|
||||||
shapeCalculationOps.insert(op);
|
shapeCalculationOps.insert(op);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1903,7 +1903,33 @@ def aten〇_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int =
|
||||||
return upstream_shape_functions.unary(v), upstream_shape_functions.unary(g)
|
return upstream_shape_functions.unary(v), upstream_shape_functions.unary(g)
|
||||||
|
|
||||||
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
||||||
return upstream_shape_functions.slice(self, dim, start, end, step)
|
start_val = start if start is not None else 0
|
||||||
|
end_val = end if end is not None else upstream_shape_functions.max_int()
|
||||||
|
if (step < 0):
|
||||||
|
# Convert to equivalent postive-step parameters, which will require swapping start and end.
|
||||||
|
# If the parameters are in the normal range (0 <= start < d and -1 <= end <= start), then
|
||||||
|
# swapped_end = start + 1 and swapped_begin = end + 1.
|
||||||
|
# The shift of inclusion can cause issues if these parameters are not already resolved on the left.
|
||||||
|
# e.g. start = -1, end = -3 . So valid start is actually d-1, and valid end is d-3. Therefore, we
|
||||||
|
# should have swapped_end = d, but adding 1 to start before making it valid would result in an
|
||||||
|
# incorrect, but "valid", swapped_end = 0 for forward slicing.
|
||||||
|
# Additionally, if adding d doesn't make these values positive, but adding twice would, we need
|
||||||
|
# to clamp after resolving, otherwise the upstream function will try to resolve a second time.
|
||||||
|
if start_val < 0:
|
||||||
|
start_val += self[dim]
|
||||||
|
if start_val < 0:
|
||||||
|
start_val = 0
|
||||||
|
if end_val < 0:
|
||||||
|
end_val += self[dim]
|
||||||
|
if end_val < 0:
|
||||||
|
end_val = -1
|
||||||
|
|
||||||
|
tmp = end_val + 1
|
||||||
|
end_val = start_val + 1
|
||||||
|
start_val = tmp
|
||||||
|
step = -step
|
||||||
|
return upstream_shape_functions.slice(self,dim,start_val,end_val,step)
|
||||||
|
|
||||||
|
|
||||||
def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
|
def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
|
||||||
return size
|
return size
|
||||||
|
|
|
@ -374,3 +374,262 @@ func.func @squeeze_dim_full_fold(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.li
|
||||||
%59 = torch.prim.ListConstruct %58 : (!torch.int) -> !torch.list<int>
|
%59 = torch.prim.ListConstruct %58 : (!torch.int) -> !torch.list<int>
|
||||||
return %59 : !torch.list<int>
|
return %59 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @pytorch_dynamic_pad_export_view$prop(
|
||||||
|
func.func @pytorch_dynamic_pad_export_view$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[4,2],si64> {
|
||||||
|
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[x2:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I144:.*]] = torch.constant.int 144
|
||||||
|
// CHECK: %[[I0_0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[I0_1:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[I0_2:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[I0_3:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[x0]], %[[I144]], %[[x1]], %[[x2]], %[[I0_0]], %[[I0_1]], %[[I0_2]], %[[I0_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[none:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[false:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4,2],si64>
|
||||||
|
// CHECK: return %[[x4]] : !torch.vtensor<[4,2],si64>
|
||||||
|
%0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||||
|
%1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%int-9223372036854775807 = torch.constant.int -9223372036854775807
|
||||||
|
%int-1 = torch.constant.int -1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64>
|
||||||
|
%4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
|
||||||
|
%5 = torch.aten.cat %4, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
|
||||||
|
%6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
|
||||||
|
%8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
|
||||||
|
%9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
|
||||||
|
%10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
|
||||||
|
%11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
|
||||||
|
%12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
return %7 : !torch.vtensor<[4,2],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @pytorch_dynamic_pad_export_slice$prop(
|
||||||
|
func.func @pytorch_dynamic_pad_export_slice$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[4,2],si64> {
|
||||||
|
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[x2:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I0_0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[I0_1:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[I0_2:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[I0_3:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[I144:.*]] = torch.constant.int 144
|
||||||
|
// CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[I0_2]], %[[I0_3]], %[[x1]], %[[x2]], %[[x0]], %[[I144]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[none:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[false:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4,2],si64>
|
||||||
|
// CHECK: return %[[x4]] : !torch.vtensor<[4,2],si64>
|
||||||
|
%0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||||
|
%1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%int-9223372036854775807 = torch.constant.int -9223372036854775807
|
||||||
|
%int-1 = torch.constant.int -1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64>
|
||||||
|
%4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
|
||||||
|
%5 = torch.aten.cat %4, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
|
||||||
|
%6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
|
||||||
|
%8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
|
||||||
|
%9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
|
||||||
|
%10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
|
||||||
|
%11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
|
||||||
|
%12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
return %8 : !torch.vtensor<[4,2],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @pytorch_dynamic_pad_export_transpose$prop(
|
||||||
|
func.func @pytorch_dynamic_pad_export_transpose$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[2,4],si64> {
|
||||||
|
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[DIM3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I0_0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[I0_1:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[I0_2:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[I0_3:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIM1:.*]] = torch.constant.int 144
|
||||||
|
// CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[DIM2]], %[[DIM0]], %[[I0_2]], %[[I0_3]], %[[DIM3]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[none:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[false:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,4],si64>
|
||||||
|
// CHECK: return %[[x4]] : !torch.vtensor<[2,4],si64>
|
||||||
|
%0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||||
|
%1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%int-9223372036854775807 = torch.constant.int -9223372036854775807
|
||||||
|
%int-1 = torch.constant.int -1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64>
|
||||||
|
%4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
|
||||||
|
%5 = torch.aten.cat %4, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
|
||||||
|
%6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
|
||||||
|
%8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
|
||||||
|
%9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
|
||||||
|
%10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
|
||||||
|
%11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
|
||||||
|
%12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
return %9 : !torch.vtensor<[2,4],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @pytorch_dynamic_pad_export_full(
|
||||||
|
func.func @pytorch_dynamic_pad_export_full(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.list<int> {
|
||||||
|
// CHECK: %[[I2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[x1:.*]] = torch.prim.ListConstruct %[[DIM2]], %[[I0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: return %[[x1]] : !torch.list<int>
|
||||||
|
%0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
|
||||||
|
%1 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%int-9223372036854775807 = torch.constant.int -9223372036854775807
|
||||||
|
%int-1 = torch.constant.int -1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64>
|
||||||
|
%4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
|
||||||
|
%5 = torch.aten.cat %4, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
|
||||||
|
%6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64>
|
||||||
|
%8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64>
|
||||||
|
%9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64>
|
||||||
|
%10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
|
||||||
|
%11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
|
||||||
|
%12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
return %16 : !torch.list<int>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @transpose$prop_3d_0_1
|
||||||
|
func.func @transpose$prop_3d_0_1(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[2,2,2],si64> {
|
||||||
|
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[SIZE0_0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[SIZE0_1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[SIZE0_2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[SIZE0_3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I0_0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[SIZE1_0:.*]] = torch.aten.size.int %arg1, %[[I0_0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I1_1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[SIZE1_1:.*]] = torch.aten.size.int %arg1, %[[I1_1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I2_2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[SIZE1_2:.*]] = torch.aten.size.int %arg1, %[[I2_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I3_3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[SIZE1_3:.*]] = torch.aten.size.int %arg1, %[[I3_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SIZE0_0]], %[[SIZE0_1]], %[[SIZE1_0]], %[[SIZE1_1]], %[[SIZE0_2]], %[[SIZE0_3]], %[[SIZE1_2]], %[[SIZE1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,2,2],si64>
|
||||||
|
// CHECK: return %[[TENSOR]] : !torch.vtensor<[2,2,2],si64>
|
||||||
|
%0 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%int-1 = torch.constant.int -1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%1 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
|
||||||
|
%2 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
|
||||||
|
%3 = torch.prim.ListConstruct %1, %2 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
|
||||||
|
%4 = torch.aten.cat %3, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
|
||||||
|
%5 = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%6 = torch.aten.view %4, %5 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[2,2,2],si64>
|
||||||
|
%7 = torch.aten.transpose.int %6, %int0, %int1 : !torch.vtensor<[2,2,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,2,2],si64>
|
||||||
|
%8 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
|
||||||
|
%9 = torch.aten.view %7, %8 : !torch.vtensor<[2,2,2],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
|
||||||
|
%10 = torch.aten.index_select %9, %int0, %0 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list<int>
|
||||||
|
return %7 : !torch.vtensor<[2,2,2],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @transpose$prop_3d_m1_0
|
||||||
|
func.func @transpose$prop_3d_m1_0(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[2,2,2],si64> {
|
||||||
|
// CHECK: %[[I0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[SIZE0_0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[SIZE0_1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[SIZE0_2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[SIZE0_3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I0_0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[SIZE1_0:.*]] = torch.aten.size.int %arg1, %[[I0_0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I1_1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[SIZE1_1:.*]] = torch.aten.size.int %arg1, %[[I1_1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I2_2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[SIZE1_2:.*]] = torch.aten.size.int %arg1, %[[I2_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[I3_3:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[SIZE1_3:.*]] = torch.aten.size.int %arg1, %[[I3_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SIZE0_0]], %[[SIZE1_0]], %[[SIZE0_2]], %[[SIZE1_2]], %[[SIZE0_1]], %[[SIZE1_1]], %[[SIZE0_3]], %[[SIZE1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,2,2],si64>
|
||||||
|
// CHECK: return %[[TENSOR]] : !torch.vtensor<[2,2,2],si64>
|
||||||
|
%0 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%int-1 = torch.constant.int -1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%1 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
|
||||||
|
%2 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64>
|
||||||
|
%3 = torch.prim.ListConstruct %1, %2 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor>
|
||||||
|
%4 = torch.aten.cat %3, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64>
|
||||||
|
%5 = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%6 = torch.aten.view %4, %5 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[2,2,2],si64>
|
||||||
|
%7 = torch.aten.transpose.int %6, %int-1, %int0 : !torch.vtensor<[2,2,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,2,2],si64>
|
||||||
|
%8 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
|
||||||
|
%9 = torch.aten.view %7, %8 : !torch.vtensor<[2,2,2],si64>, !torch.list<int> -> !torch.vtensor<[8],si64>
|
||||||
|
%10 = torch.aten.index_select %9, %int0, %0 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
%11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list<int>
|
||||||
|
return %7 : !torch.vtensor<[2,2,2],si64>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue