@@ -202,12 +202,13 @@ public:
           op, "Expected a constant boolean value for keepDim");
 
     Value input = op.getSelf();
-    std::sort(dims.begin(), dims.end());
     // For every dimension included in `dim` of the op, iterated over in
     // reverse order, we create a call to aten.max.dim.
-    for (int64_t i = dims.size() - 1; i >= 0; i--) {
+    std::sort(dims.begin(), dims.end());
+    std::reverse(dims.begin(), dims.end());
+    for (int64_t dimInt : dims) {
       Value dim = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(dims[i]));
+          loc, rewriter.getI64IntegerAttr(dimInt));
       // The input to the next invocation of aten.max.dim is the output of the
       // previous aten.max.dim op.
       input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
@@ -227,11 +228,12 @@ public:
     Location loc = op.getLoc();
     Value self = op.getSelf();
     MLIRContext *context = op.getContext();
-    int64_t rank = getTensorRank(self);
-    if (rank < 0)
+    Optional<unsigned> maybeRank = getTensorRank(self);
+    if (!maybeRank)
       return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
+    unsigned rank = *maybeRank;
     SmallVector<Value> sizes;
-    for (int i = 0; i < rank; i++) {
+    for (unsigned i = 0; i < rank; i++) {
       Value dim = rewriter.create<Torch::ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(i));
       sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, self, dim));
@@ -546,7 +548,12 @@ public:
 
     BaseTensorType inputType = input.getType().cast<BaseTensorType>();
     BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
-
+    Optional<unsigned> maybeInputRank = getTensorRank(input);
+    if (!maybeInputRank) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input tensor to have a rank");
+    }
+    unsigned inputRank = *maybeInputRank;
     if (!indicesTensorType.hasSizes())
       return failure();
     BaseTensorType valueTensorType =
@@ -565,7 +572,7 @@ public:
             .cast<BaseTensorType>();
     dim = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
     Value end = rewriter.create<ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(getTensorRank(input) - 1));
+        loc, rewriter.getI64IntegerAttr(inputRank - 1));
     input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
                                                     dim, end);
   }
@@ -674,8 +681,14 @@ public:
     Value lhs = op.getSelf();
     Value rhs = op.getOther();
 
-    int lhsRank = getTensorRank(lhs);
-    int rhsRank = getTensorRank(rhs);
+    Optional<unsigned> maybeLhsRank = getTensorRank(lhs);
+    Optional<unsigned> maybeRhsRank = getTensorRank(rhs);
+    if (!maybeLhsRank || !maybeRhsRank) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input tensors to have a rank");
+    }
+    unsigned lhsRank = *maybeLhsRank;
+    unsigned rhsRank = *maybeRhsRank;
 
     if (lhsRank == 2 && rhsRank == 2) {
       // If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
@@ -773,15 +786,17 @@ public:
   LogicalResult matchAndRewrite(AtenTOp op,
                                 PatternRewriter &rewriter) const override {
     Value lhs = op.getSelf();
-    int lhsRank = getTensorRank(lhs);
+    Optional<unsigned> lhsRank = getTensorRank(lhs);
     auto loc = op.getLoc();
 
-    if (lhsRank > 2 || lhsRank < 0) {
+    if (!lhsRank) {
+      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
+    } else if (*lhsRank > 2) {
       std::string errorMessage =
           "t() expects a tensor with <=2 dimensions, but self is " +
-          std::to_string(lhsRank) + "D";
+          std::to_string(*lhsRank) + "D";
       return rewriter.notifyMatchFailure(op, errorMessage.c_str());
-    } else if (lhsRank < 2)
+    } else if (*lhsRank < 2)
       rewriter.replaceOp(op, lhs);
     else {
       Value zero =
@@ -846,9 +861,10 @@ public:
           loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
       return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
     };
-    int rank = getTensorRank(self);
-    if (rank < 0)
+    Optional<unsigned> maybeRank = getTensorRank(self);
+    if (!maybeRank)
      return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
+    unsigned rank = *maybeRank;
     Value output = self;
     auto nShifts = shifts.size();
     for (size_t k = 0; k < nShifts; ++k) {
@@ -901,16 +917,17 @@ public:
     Location loc = op.getLoc();
     Value self = op.getSelf();
     MLIRContext *context = op.getContext();
-    int rank = getTensorRank(self);
-    if (rank < 0)
+    Optional<unsigned> maybeRank = getTensorRank(self);
+    if (!maybeRank)
       return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
+    unsigned rank = *maybeRank;
 
     SmallVector<Value> repeats;
     if (!getListConstructElements(op.getRepeats(), repeats))
       return rewriter.notifyMatchFailure(
           op, "Unimplemented: repeats not list of Scalar");
 
-    if (rank > (int)repeats.size()) {
+    if (rank > repeats.size()) {
       return rewriter.notifyMatchFailure(
           op, "repeats are not matched with self's rank");
     }
@@ -946,7 +963,7 @@ public:
 
     auto selfType = self.getType().dyn_cast<BaseTensorType>();
     auto selfShape = selfType.getSizes();
-    for (int i = 0; i < rank; i++) {
+    for (unsigned i = 0; i < rank; i++) {
       auto scale = repeats[i + leadingRank];
       Value dimSize;
       if (selfShape[i] == kUnknownSize) {
@@ -1003,9 +1020,10 @@ public:
     Location loc = op.getLoc();
     Value self = op.getSelf();
     MLIRContext *context = op.getContext();
-    int64_t rank = getTensorRank(self);
-    if (rank < 0)
+    Optional<unsigned> maybeRank = getTensorRank(self);
+    if (!maybeRank)
       return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
+    unsigned rank = *maybeRank;
 
     int64_t start, end;
     if (!matchPattern(op.getStartDim(), m_TorchConstantInt(&start)) ||
@@ -1239,6 +1257,13 @@ public:
 
     Location loc = op.getLoc();
     MLIRContext *context = op.getContext();
+    Value gradOutput = op.getGradOutput();
+    Optional<unsigned> maybeGradRank = getTensorRank(gradOutput);
+    if (!maybeGradRank) {
+      return rewriter.notifyMatchFailure(op,
+                                         "expected grad output to have a rank");
+    }
+    unsigned gradRank = *maybeGradRank;
     Value cstZero = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(0));
     Value cstOne = rewriter.create<Torch::ConstantIntOp>(
@@ -1249,10 +1274,8 @@ public:
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
         loc, rewriter.getBoolAttr(false));
 
-    Value gradOutput = op.getGradOutput();
     Value input = op.getInput();
     Value weight = op.getWeight();
-    auto gradRank = getTensorRank(gradOutput);
 
     if (gradRank != 4)
       return rewriter.notifyMatchFailure(
@@ -1299,7 +1322,7 @@ public:
 
     // Rotate weight.
     SmallVector<Value> axes;
-    for (auto i = 2; i < gradRank; i++) {
+    for (unsigned i = 2; i < gradRank; i++) {
       axes.push_back(rewriter.create<Torch::ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(i)));
     }
@@ -1309,7 +1332,7 @@ public:
         axesList);
     // Calculate padding for first convolution.
     SmallVector<Value> gradInputPaddingValues;
-    for (auto i = 2; i < gradRank; i++) {
+    for (unsigned i = 2; i < gradRank; i++) {
       Value dim = rewriter.create<Torch::ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(i));
       Value outDim = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
@@ -1359,7 +1382,7 @@ public:
         loc, gradWeight.getType(), gradWeight, cstZero, cstOne);
 
     SmallVector<Value> dimIntList{cstZero};
-    for (auto i = 2; i < gradRank; i++)
+    for (unsigned i = 2; i < gradRank; i++)
       dimIntList.push_back(rewriter.create<Torch::ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(i)));
     Value gradIntList = rewriter.create<Torch::PrimListConstructOp>(
@@ -1387,9 +1410,11 @@ public:
     Value input = op.getSelf();
     Value mat1 = op.getMat1();
     Value mat2 = op.getMat2();
+    Optional<unsigned> mat1Rank = getTensorRank(mat1);
+    Optional<unsigned> mat2Rank = getTensorRank(mat2);
 
     // The operands `mat1`, `mat2` to aten.addmm must be of rank 2.
-    if (getTensorRank(mat1) != 2 || getTensorRank(mat2) != 2) {
+    if (!mat1Rank || !mat2Rank || *mat1Rank != 2 || *mat2Rank != 2) {
       return rewriter.notifyMatchFailure(
           op, "expected mat1, mat2 operands to aten.addmm to be rank 2");
     }
@@ -1447,7 +1472,12 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.getSelf();
-    unsigned inputRank = getTensorRank(input);
+    Optional<unsigned> maybeInputRank = getTensorRank(input);
+    if (!maybeInputRank) {
+      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
+    }
+    unsigned inputRank = *maybeInputRank;
+
     Value dimList = op.getDim();
     Value keepDim = op.getKeepdim();
     Value dtype = op.getDtype();
@@ -1572,7 +1602,11 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value self = op.getSelf();
-    unsigned inputRank = getTensorRank(self);
+    Optional<unsigned> maybeInputRank = getTensorRank(self);
+    if (!maybeInputRank) {
+      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
+    }
+    unsigned inputRank = *maybeInputRank;
     BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
     if (!rank0FloatTensorTy.hasSizes() ||
         rank0FloatTensorTy.getSizes().size() != 0) {
@@ -2129,10 +2163,11 @@ class DecomposeAtenNativeBatchNormOp
 
     // Rank of the input tensor must be greater than or equal to 2. The shape of
     // the `input` is supposed to be (N, C, D?, H?, W?).
-    int64_t inputRank = getTensorRank(input);
-    if (inputRank < 2)
+    Optional<unsigned> maybeInputRank = getTensorRank(input);
+    if (!maybeInputRank || *maybeInputRank < 2)
       return rewriter.notifyMatchFailure(
           op, "input must have rank greater than or equal to 2");
+    unsigned inputRank = *maybeInputRank;
 
     // In the inference mode, the `runningMean` and `runningVar` must not be
     // None.
@@ -2142,7 +2177,10 @@ class DecomposeAtenNativeBatchNormOp
           op, "running stats must not be None in inference mode");
 
     // Rank of `runningMean` and `runningVar` must be exactly 1.
-    if (getTensorRank(runningMean) != 1 || getTensorRank(runningVar) != 1)
+    Optional<unsigned> runningMeanRank = getTensorRank(runningMean);
+    Optional<unsigned> runningVarRank = getTensorRank(runningVar);
+    if (!runningMeanRank || !runningVarRank || *runningMeanRank != 1 ||
+        *runningVarRank != 1)
       return rewriter.notifyMatchFailure(
           op, "expected runningMean and runningVar to be rank 1");
 
@@ -2191,7 +2229,8 @@ class DecomposeAtenNativeBatchNormOp
     Value batchNormOutput = normalizedInput;
     if (!weight.getType().isa<Torch::NoneType>()) {
       // Rank of `weight` must be exactly 1.
-      if (getTensorRank(weight) != 1)
+      Optional<unsigned> weightRank = getTensorRank(weight);
+      if (!weightRank || *weightRank != 1)
         return rewriter.notifyMatchFailure(op, "expected weight to be rank 1");
       weight = rewriter.create<AtenViewOp>(loc, reshapeType, weight,
                                            runningStatsSizeList);
@@ -2200,7 +2239,8 @@ class DecomposeAtenNativeBatchNormOp
     }
     if (!bias.getType().isa<Torch::NoneType>()) {
       // Rank of `bias` must be exactly 1.
-      if (getTensorRank(bias) != 1)
+      Optional<unsigned> biasRank = getTensorRank(bias);
+      if (!biasRank || *biasRank != 1)
         return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
       bias = rewriter.create<AtenViewOp>(loc, reshapeType, bias,
                                          runningStatsSizeList);
@@ -2619,7 +2659,11 @@ class DecomposeAtenAdaptiveAvgPool2dOp
     MLIRContext *context = op.getContext();
 
     Value input = op.getSelf();
-    int64_t rank = getTensorRank(input);
+    Optional<unsigned> maybeRank = getTensorRank(input);
+    if (!maybeRank) {
+      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
+    }
+    unsigned rank = *maybeRank;
     SmallVector<Value, 2> inputHW;
     Value dimH = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(rank - 2));
@@ -2788,12 +2832,19 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value self = op.getSelf();
-    int64_t inputRank = getTensorRank(self);
+    Optional<unsigned> maybeInputRank = getTensorRank(self);
+    if (!maybeInputRank) {
+      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
+    }
+    unsigned inputRank = *maybeInputRank;
 
     SmallVector<Value> dimListElements;
-    for (int64_t i = inputRank - 1; i >= 0; i--)
+    SmallVector<int> dimListInts(llvm::reverse(
+        llvm::iota_range<int>(0, inputRank, /*inclusive=*/false)));
+    for (int dimListInt : dimListInts) {
       dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(i)));
+          loc, rewriter.getI64IntegerAttr(dimListInt)));
+    }
     Value dimList = rewriter.create<PrimListConstructOp>(
         loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
         dimListElements);
@@ -2828,7 +2879,11 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
     inputTensorTy = self.getType().cast<BaseTensorType>();
   }
 
-  unsigned inputRank = getTensorRank(self);
+  Optional<unsigned> maybeInputRank = getTensorRank(self);
+  if (!maybeInputRank) {
+    return rewriter.notifyMatchFailure(op, "expected input to have a rank");
+  }
+  unsigned inputRank = *maybeInputRank;
   SmallVector<Value> dimListElements;
   bool isNoneOrEmpty = true;
   if (!dimList.getType().template isa<Torch::NoneType>()) {
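
The change repeated across these hunks: getTensorRank used to return a signed
integer with a negative value standing in for "unranked", and now returns an
Optional<unsigned> that is empty for unranked tensors, so each rewrite pattern
must reject the unranked case explicitly via notifyMatchFailure before using
the rank. A minimal sketch of the helper the call sites assume; its real
definition lives in torch-mlir's Torch utility library, outside this diff, so
the exact body shown here is an assumption:

  // Sketch only (assumed definition, not part of this patch): report the
  // rank when the tensor type carries size information, otherwise None.
  Optional<unsigned> getTensorRank(Value tensor) {
    BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
    if (!tensorType.hasSizes())
      return llvm::None;
    return tensorType.getSizes().size();
  }

  // Caller shape used throughout the hunks above: bail out of the rewrite
  // instead of comparing against a negative sentinel, then unwrap.
  Optional<unsigned> maybeRank = getTensorRank(self);
  if (!maybeRank)
    return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
  unsigned rank = *maybeRank;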