//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

LogicalResult materializeFolds(ImplicitLocOpBuilder b,
                               ArrayRef<OpFoldResult> fold,
                               SmallVectorImpl<Value> &values) {
  for (auto f : fold) {
    if (auto val = dyn_cast<Value>(f)) {
      values.push_back(val);
      continue;
    }

    if (auto attr = dyn_cast<Attribute>(f)) {
      if (auto val = dyn_cast<FloatAttr>(attr)) {
        values.push_back(b.create<Torch::ConstantFloatOp>(
            b.getType<Torch::FloatType>(), val));
        continue;
      }
      if (auto val = dyn_cast<IntegerAttr>(attr)) {
        values.push_back(
            b.create<Torch::ConstantIntOp>(val.getValue().getSExtValue()));
        continue;
      }
    }

    return failure();
  }

  return success();
}

LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
  auto list = value.getDefiningOp<Torch::PrimListConstructOp>();
  if (!list)
    return failure();

  for (auto operand : list.getOperands())
    vals.push_back(operand);

  return success();
}

LogicalResult getListFromTensor(Value value, SmallVector<OpFoldResult> &vals) {
  constexpr int64_t kMaxFold = 16;
  if (auto tensor = value.getDefiningOp<Torch::AtenTensorOp>()) {
    SmallVector<Value> unfolded;
    LogicalResult gotList = getListOperands(tensor.getData(), unfolded);
    vals = getAsOpFoldResult(unfolded);
    return gotList;
  }

  if (auto full = value.getDefiningOp<Torch::AtenFullOp>()) {
    auto ty = cast<ValueTensorType>(full.getType());
    if (!ty.areAllSizesKnown() || ty.getSizes().size() != 1)
      return failure();

    if (ty.getSizes()[0] > kMaxFold)
      return failure();

    vals.resize(vals.size() + ty.getSizes()[0],
                getAsOpFoldResult(full.getFillValue()));
    return success();
  }

  if (auto unsqueeze = value.getDefiningOp<Torch::AtenUnsqueezeOp>()) {
    Value usqSelf = unsqueeze.getSelf();
    if (auto numToTensor =
            usqSelf.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
      vals.push_back(getAsOpFoldResult(numToTensor.getA()));
      return success();
    }
  }

  // A common rank 0 tensor producer
  if (auto numToTensor =
          value.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
    vals.push_back(getAsOpFoldResult(numToTensor.getA()));
    return success();
  }

  // Last supported case: ValueTensorLiteralOp
  auto literalOp = value.getDefiningOp<Torch::ValueTensorLiteralOp>();
  if (!literalOp)
    return failure();

  // Check the type.
  auto ty = cast<ValueTensorType>(literalOp.getType());
  if (!ty.hasSizes() || ty.getSizes().size() > 1)
    return failure();
  // make sure the type is not unsigned here before trying to materialize
  auto intTy = dyn_cast_or_null<IntegerType>(ty.getDtype());
  if (!intTy || intTy.isUnsigned())
    return failure();

  // if we have a rank 0 literal, we will be adding one element to the list
  int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;

  if (listSize > kMaxFold)
    return failure();

  // check for a splat or dense attr
  auto splattr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
  auto denseAttr =
      dyn_cast_or_null<DenseIntElementsAttr>(literalOp.getValue());

  if (!splattr && !denseAttr)
    return failure();

  // These are not mutually exclusive, so try splat first.
  if (splattr) {
    auto attr = splattr.getSplatValue<Attribute>();
    vals.resize((int64_t)vals.size() + listSize, attr);
    return success();
  }

  // remaining case: denseAttr
  if ((int64_t)denseAttr.getValues<Attribute>().size() != listSize)
    return failure();
  for (auto e : denseAttr.getValues<Attribute>())
    vals.push_back(e);
  return success();
}

Value constructAtenTensorOpFromList(ImplicitLocOpBuilder b,
                                    mlir::Type resultTy,
                                    SmallVector<Value> &listValues) {
  auto dimList = b.create<Torch::PrimListConstructOp>(
      b.getType<Torch::ListType>(listValues.front().getType()), listValues);
  Value cstNone = b.create<Torch::ConstantNoneOp>();
  Value cstFalse = b.create<Torch::ConstantBoolOp>(b.getBoolAttr(false));
  return b.create<Torch::AtenTensorOp>(resultTy, dimList, cstNone, cstNone,
                                       cstFalse);
}
} // namespace

/// ------ Propagation Patterns ------ ///
// The general goal of these patterns is to convert SomeTensorOp to [scalarOps
// -> PrimListOfInts -> AtenTensorOp]. Since these tensorized shape calculation
// ops are chained together, sequences like OpA -> OpB will propagate OpA
// first: [scalarOpsA -> ListA -> TensorA] -> OpB. Then OpB will be able to
// getListFromTensor(A), and further propagate scalarization.

namespace {
class PropagateAtenBroadcastToPattern
    : public OpRewritePattern<AtenBroadcastToOp> {
public:
  using OpRewritePattern<AtenBroadcastToOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenBroadcastToOp op,
                                PatternRewriter &rewriter) const override {
    constexpr int64_t kMaxFold = 16;
    // for tensor<si64>, or tensor<1xsi64>, broadcasted to tensor<?xsi64>, grab
    // the element and convert to a full op.
    auto ty = cast<ValueTensorType>(op.getType());
    if (!ty.areAllSizesKnown() || ty.getSizes().size() != 1)
      return failure();

    if (ty.getSizes()[0] > kMaxFold)
      return failure();

    SmallVector<OpFoldResult> fillFold;
    if (failed(getListFromTensor(op.getSelf(), fillFold)) ||
        fillFold.size() != 1)
      return failure();

    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    SmallVector<Value> fillVals;
    if (failed(materializeFolds(b, fillFold, fillVals)))
      return failure();

    Value size = b.create<Torch::ConstantIntOp>(ty.getSizes().front());
    Value sizeList = b.create<Torch::PrimListConstructOp>(
        rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
        size);
    Value none = b.create<Torch::ConstantNoneOp>();
    Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
    rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
        op, ty, sizeList, fillVals.front(), none, none, none, cstFalse);
    return success();
  }
};
} // namespace

namespace {
class PropagateAtenShapeToTensorPattern
    : public OpRewritePattern<Aten_ShapeAsTensorOp> {
public:
  using OpRewritePattern<Aten_ShapeAsTensorOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(Aten_ShapeAsTensorOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto self = op.getSelf();
    auto selfTy = cast<BaseTensorType>(self.getType());
    if (!selfTy.hasSizes())
      return rewriter.notifyMatchFailure(op, "self has unknown rank");

    int64_t rank = selfTy.getSizes().size();
    SmallVector<OpFoldResult> dims;
    for (int64_t i = 0; i < rank; ++i) {
      auto iv = b.create<Torch::ConstantIntOp>(i);
      dims.push_back(b.createOrFold<AtenSizeIntOp>(
          rewriter.getType<Torch::IntType>(), self, iv));
    }

    SmallVector<Value> materializedDims;
    if (failed(materializeFolds(b, dims, materializedDims))) {
      return failure();
    }

    Value result =
        constructAtenTensorOpFromList(b, op.getType(), materializedDims);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
class PropagateAtenCatPattern : public OpRewritePattern<AtenCatOp> {
public:
  using OpRewritePattern<AtenCatOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenCatOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    constexpr int64_t kMaxFold = 16;

    auto resultTy = dyn_cast<ValueTensorType>(op.getType());
    if (!resultTy.hasSizes() || resultTy.getSizes().size() != 1 ||
        !resultTy.areAllSizesKnown())
      return failure();

    if (resultTy.getSizes().front() > kMaxFold)
      return failure();
    if (!resultTy.hasDtype())
      return failure();

    SmallVector<Value> tensors;
    if (failed(getListOperands(op.getTensors(), tensors)))
      return failure();

    SmallVector<OpFoldResult> scalars;
    for (auto element : tensors) {
      llvm::SmallVector<OpFoldResult> delisted;
      if (failed(getListFromTensor(element, delisted)))
        return rewriter.notifyMatchFailure(op, "unknown op fold type");

      for (auto scalar : delisted)
        scalars.push_back(scalar);
    }

    SmallVector<Value> values;
    if (failed(materializeFolds(b, scalars, values)) || values.empty())
      return rewriter.notifyMatchFailure(op, "unable to materialize constants");

    Value result = constructAtenTensorOpFromList(b, resultTy, values);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
class PropagateAtenIndexSelectPattern
    : public OpRewritePattern<AtenIndexSelectOp> {
public:
  using OpRewritePattern<AtenIndexSelectOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenIndexSelectOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);

    SmallVector<OpFoldResult> elements;
    if (failed(getListFromTensor(op.getSelf(), elements)))
      return failure();

    int64_t dim;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
      return rewriter.notifyMatchFailure(op, "requires a constant dim");

    SmallVector<OpFoldResult> idxFolds;
    if (failed(getListFromTensor(op.getIndex(), idxFolds)))
      return rewriter.notifyMatchFailure(op, "requires a constant index");

    auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
    if (!selfTy.hasSizes())
      return rewriter.notifyMatchFailure(op, "requires known rank");

    auto selfShape = selfTy.getSizes();
    int64_t selfRank = selfShape.size();
    dim = dim < 0 ? dim + selfRank : dim;
    int64_t dimLength = elements.size();
    if (selfShape[dim] != dimLength)
      return rewriter.notifyMatchFailure(
          op, "dim length does not match number of elements");

    for (int64_t i = 0; i < selfRank; ++i) {
      if (i == dim)
        continue;
      if (selfShape[i] != 1)
        return rewriter.notifyMatchFailure(op,
                                           "expects unary non-dim dimension");
    }

    SmallVector<OpFoldResult> selected;
    for (auto idx : idxFolds) {
      auto attr = dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(idx));
      if (!attr)
        return failure();
      int64_t indexInt = attr.getValue().getSExtValue();
      indexInt = indexInt < 0 ? indexInt + dimLength : indexInt;
      if (indexInt < 0 || indexInt >= dimLength)
        return failure();
      selected.push_back(elements[indexInt]);
    }

    SmallVector<Value> materializedSelected;
    if (failed(materializeFolds(b, selected, materializedSelected)))
      return failure();

    Value result =
        constructAtenTensorOpFromList(b, op.getType(), materializedSelected);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
// Conversion attempts to handle some common propagatable slice cases, namely
// splatted values, no-op slices, known list of values, or any case where a
// new construction can be generated from a previous set of scalars allowing
// the parent tensor to be bypassed.
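// Rough illustration (hypothetical IR, not taken from a test case): once
// PropagateAtenShapeToTensorPattern has rewritten a shape tensor into
//   %si = torch.aten.size.int ... -> !torch.int      (one per dim)
//   %lst = torch.prim.ListConstruct %si0, %si1, %si2 -> !torch.list<int>
//   %t = torch.aten.tensor %lst, ... -> !torch.vtensor<[3],si64>
// a slice such as torch.aten.slice.Tensor %t, %int0, %int0, %int2, %int1 can
// be rebuilt directly from the selected scalars (%si0 and %si1 here), so the
// parent aten.tensor is bypassed and can later be removed as dead.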
class PropagateAtenSliceTensorPattern
    : public OpRewritePattern<AtenSliceTensorOp> {
public:
  using OpRewritePattern<AtenSliceTensorOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSliceTensorOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);

    SmallVector<OpFoldResult> elements;
    if (failed(getListFromTensor(op.getSelf(), elements)))
      return failure();

    int64_t dim, start, end, step;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
      return rewriter.notifyMatchFailure(op, "requires a constant dim");

    if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
      return rewriter.notifyMatchFailure(op, "requires a constant start");

    if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
      return rewriter.notifyMatchFailure(op, "requires a constant end");

    if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
      return rewriter.notifyMatchFailure(op, "requires a constant step");

    if (step < 0)
      return rewriter.notifyMatchFailure(op, "requires a positive step value");

    auto selfTy = cast<BaseTensorType>(op.getSelf().getType());
    auto selfShape = selfTy.getSizes();
    int64_t selfRank = selfShape.size();

    // Correct for negative indexing:
    dim = dim < 0 ? dim + selfRank : dim;

    int64_t dimLength = elements.size();
    start = start < 0 ? start + dimLength : start;
    end = end < 0 ? end + dimLength : end;

    start = start < 0 ? 0 : start;
    end = end < 0 ? 0 : end;
    end = end > dimLength ? dimLength : end;

    if (selfShape[dim] != dimLength)
      return rewriter.notifyMatchFailure(
          op, "dim length does not match number of elements");

    for (int64_t i = 0; i < selfRank; ++i) {
      if (i == dim)
        continue;
      if (selfShape[i] != 1)
        return rewriter.notifyMatchFailure(op,
                                           "expects unary non-dim dimension");
    }

    SmallVector<OpFoldResult> selected;
    for (int i = start; i < end; i += step)
      selected.push_back(elements[i]);

    SmallVector<Value> values;
    if (failed(materializeFolds(b, selected, values)))
      return failure();

    Value result = constructAtenTensorOpFromList(b, op.getType(), values);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
class PropagateAtenWhereSelfPattern
    : public OpRewritePattern<AtenWhereSelfOp> {
public:
  using OpRewritePattern<AtenWhereSelfOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenWhereSelfOp op,
                                PatternRewriter &rewriter) const override {
    Value condition = op.getCondition();
    Value self = op.getSelf();
    Value other = op.getOther();
    auto conditionTy = dyn_cast<Torch::ValueTensorType>(condition.getType());
    if (!conditionTy || !conditionTy.hasSizes() ||
        conditionTy.getSizes().size() != 1)
      return rewriter.notifyMatchFailure(op, "bad condition type");
    auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
    if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
      return rewriter.notifyMatchFailure(op, "bad self type");
    auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
    if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
      return rewriter.notifyMatchFailure(op, "bad other type");
    int64_t conditionSize = conditionTy.getSizes()[0];
    int64_t selfSize = selfTy.getSizes()[0];
    int64_t otherSize = otherTy.getSizes()[0];

    if (selfSize != otherSize || selfSize != conditionSize)
      return rewriter.notifyMatchFailure(
          op,
          "unimplemented: support for propagating with implicit broadcasting.");

    constexpr int64_t kMaxFold = 16;
    if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold)
      return rewriter.notifyMatchFailure(op,
                                         "arguments are dynamic or too big");

    SmallVector<OpFoldResult> conditionFolds, selfFolds, otherFolds;
    if (failed(getListFromTensor(condition, conditionFolds)) ||
        failed(getListFromTensor(self, selfFolds)) ||
        failed(getListFromTensor(other, otherFolds)))
      return failure();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    SmallVector<Value> conditionList, selfList, otherList;
    if (failed(materializeFolds(b, conditionFolds, conditionList)) ||
        failed(materializeFolds(b, selfFolds, selfList)) ||
        failed(materializeFolds(b, otherFolds, otherList)))
      return failure();

    SmallVector<Value> whereVals;
    auto rank0IntTy = rewriter.getType<Torch::ValueTensorType>(
        ArrayRef<int64_t>({}), selfTy.getDtype());
    auto rank0BoolTy = rewriter.getType<Torch::ValueTensorType>(
        ArrayRef<int64_t>({}), conditionTy.getDtype());
    for (uint64_t i = 0; i < selfList.size(); i++) {
      Value rank0Cond = b.create<Torch::PrimNumToTensorScalarOp>(
          rank0BoolTy, conditionList[i]);
      Value rank0Self =
          b.create<Torch::PrimNumToTensorScalarOp>(rank0IntTy, selfList[i]);
      Value rank0Other =
          b.create<Torch::PrimNumToTensorScalarOp>(rank0IntTy, otherList[i]);
      Value rank0Where = b.create<AtenWhereSelfOp>(rank0IntTy, rank0Cond,
                                                   rank0Self, rank0Other);
      whereVals.push_back(b.create<AtenItemOp>(
          rewriter.getType<Torch::IntType>(), rank0Where));
    }
    Value result = constructAtenTensorOpFromList(b, op.getType(), whereVals);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
class PropagateAtenEqTensorPattern : public OpRewritePattern<AtenEqTensorOp> {
public:
  using OpRewritePattern<AtenEqTensorOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenEqTensorOp op,
                                PatternRewriter &rewriter) const override {
    Value self = op.getSelf();
    Value other = op.getOther();
    auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
    if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
      return rewriter.notifyMatchFailure(op, "bad self type");
    auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
    if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
      return rewriter.notifyMatchFailure(op, "bad other type");
    int64_t selfSize = selfTy.getSizes()[0];
    int64_t otherSize = otherTy.getSizes()[0];

    if (selfSize != otherSize)
      return rewriter.notifyMatchFailure(
          op,
          "unimplemented: support for propagating with implicit broadcasting.");

    constexpr int64_t kMaxFold = 16;
    if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold ||
        otherSize == Torch::kUnknownSize || otherSize > kMaxFold)
      return rewriter.notifyMatchFailure(
          op, "self or other is dynamic or too big");

    SmallVector<OpFoldResult> selfFolds, otherFolds;
    if (failed(getListFromTensor(self, selfFolds)) ||
        failed(getListFromTensor(other, otherFolds)))
      return rewriter.notifyMatchFailure(op, "failed to get list from tensor");

    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    SmallVector<Value> selfList, otherList;
    if (failed(materializeFolds(b, selfFolds, selfList)) ||
        failed(materializeFolds(b, otherFolds, otherList)))
      return rewriter.notifyMatchFailure(op, "failed to materialize folds");

    SmallVector<OpFoldResult> eqBoolFolds;
    for (uint64_t i = 0; i < selfList.size(); i++) {
      OpFoldResult eqInt =
          b.createOrFold<AtenEqIntOp>(selfList[i], otherList[i]);
      if (auto eqIntVal = dyn_cast<Value>(eqInt))
        eqInt = b.createOrFold<AtenIntBoolOp>(eqIntVal);
      // if eqInt was an Attribute, it will materialize to a constant int op,
      // which is what we want.
      eqBoolFolds.push_back(eqInt);
    }
    SmallVector<Value> eqVals;
    if (failed(materializeFolds(b, eqBoolFolds, eqVals))) {
      return failure();
    }

    Value result = constructAtenTensorOpFromList(b, op.getType(), eqVals);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

namespace {
class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
public:
  using OpRewritePattern<AtenItemOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenItemOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> elements;
    Value self = op.getSelf();
    auto selfTy = cast<ValueTensorType>(self.getType());
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // Rank 0 item op prop
    if (selfTy.getSizes().size() == 0) {
      auto numToTensor = self.getDefiningOp<Torch::PrimNumToTensorScalarOp>();
      auto squeezeDim = self.getDefiningOp<AtenSqueezeDimOp>();
      if (!squeezeDim && !numToTensor)
        return rewriter.notifyMatchFailure(op,
                                           "unhandled item of rank 0 operand");
      if (numToTensor) {
        rewriter.replaceOp(op, numToTensor.getA());
        return success();
      }
      rewriter.replaceOpWithNewOp<AtenItemOp>(op, op.getType(),
                                              squeezeDim.getSelf());
      return success();
    }

    // Rank 1 item op prop
    if (failed(getListFromTensor(op.getSelf(), elements)))
      return failure();

    if (elements.size() != 1)
      return rewriter.notifyMatchFailure(op, "expected one element");

    SmallVector<Value> materialized;
    if (failed(materializeFolds(b, elements, materialized)))
      return failure();

    rewriter.replaceOp(op, materialized.front());
    return success();
  }
};
} // namespace

namespace {

template <typename OpTy> struct ArithmeticHelper {
  static LogicalResult getAlphaAndVerify(OpTy &op, int64_t &alpha) {
    alpha = 1;
    return success();
  }
};

template <> struct ArithmeticHelper<AtenAddTensorOp> {
  static LogicalResult getAlphaAndVerify(AtenAddTensorOp &op, int64_t &alpha) {
    if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
      return failure();
    return success();
  }
};

template <> struct ArithmeticHelper<AtenSubTensorOp> {
  static LogicalResult getAlphaAndVerify(AtenSubTensorOp &op, int64_t &alpha) {
    if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
      return failure();
    return success();
  }
};

template <typename OpTy, typename ScalarOpTy>
class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Check type
    auto resultTy = cast<ValueTensorType>(op.getType());
    if (resultTy.getSizes().size() > 1)
      return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
    if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
      return rewriter.notifyMatchFailure(op, "not an int type");

    int64_t alpha;
    if (failed(ArithmeticHelper<OpTy>::getAlphaAndVerify(op, alpha)))
      return rewriter.notifyMatchFailure(op, "alpha must be 1");

    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    SmallVector<OpFoldResult> selfFold, otherFold;
    if (failed(getListFromTensor(op.getSelf(), selfFold)) ||
        failed(getListFromTensor(op.getOther(), otherFold)) ||
        selfFold.size() != otherFold.size())
      return failure();
    SmallVector<Value> selfVals, otherVals;
    if (failed(materializeFolds(b, selfFold, selfVals)) ||
        failed(materializeFolds(b, otherFold, otherVals)))
      return failure();
    SmallVector<OpFoldResult> resultFolds;
    for (uint64_t i = 0; i < selfVals.size(); i++) {
      resultFolds.push_back(b.createOrFold<ScalarOpTy>(
          selfVals[i].getType(), selfVals[i], otherVals[i]));
    }
    SmallVector<Value> resultVals;
    if (failed(materializeFolds(b, resultFolds, resultVals)))
      return failure();

    if (resultTy.getSizes().size() == 0) {
      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
          op, resultTy, resultVals.front());
      return success();
    }

    Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace
/// ------ Fold Patterns ------ ///
// These are shape-specific folding patterns

namespace {
class FoldAtenEqIntPattern : public OpRewritePattern<AtenEqIntOp> {
public:
  using OpRewritePattern<AtenEqIntOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenEqIntOp op,
                                PatternRewriter &rewriter) const override {
    // replaces (size.int == 0) with false and adds an assert
    // these comparisons are getting generated because onnx.Reshape considers 0
    // to mean "don't change this dim". However, if the size we are passing to
    // onnx.Reshape is a tensor dim, this is definitely never supposed to be
    // interpreted as "don't change this dim".
    int64_t otherInt;
    if (!matchPattern(op.getB(), m_TorchConstantInt(&otherInt)) ||
        otherInt != 0)
      return failure();

    // in case the shape is a product of two ints, check each
    if (auto mulOp = op.getA().getDefiningOp<AtenMulIntOp>()) {
      Value self = mulOp.getA();
      Value other = mulOp.getB();
      Value selfEq =
          rewriter.create<AtenEqIntOp>(op.getLoc(), self, op.getB());
      Value otherEq =
          rewriter.create<AtenEqIntOp>(op.getLoc(), other, op.getB());
      rewriter.replaceOpWithNewOp<Aten__Or__BoolOp>(op, selfEq, otherEq);
      return success();
    }

    // if lhs is size.int op, assert size > 0 and replace with false.
    if (auto sizeOp = op.getA().getDefiningOp<AtenSizeIntOp>()) {
      Value selfGtOther = rewriter.create<AtenGtIntOp>(
          op.getLoc(), op.getType(), op.getA(), op.getB());
      rewriter.create<Torch::RuntimeAssertOp>(
          op.getLoc(), selfGtOther,
          rewriter.getStringAttr("Expected dim size > 0."));
      Value cstFalse =
          rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
      rewriter.replaceOp(op, cstFalse);
      return success();
    }

    return failure();
  }
};
} // namespace

namespace {
class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
public:
  using OpRewritePattern<AtenTensorOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTensorOp op,
                                PatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    SmallVector<Value> elements;
    if (failed(getListOperands(op.getData(), elements)))
      return failure();

    if (elements.size() < 1)
      return rewriter.notifyMatchFailure(op, "no elements");

    auto front = elements.front();
    for (auto element : elements)
      if (element != front)
        return rewriter.notifyMatchFailure(op, "multiple elements found");

    if (elements.size() != 1)
      return rewriter.notifyMatchFailure(op, "expected exactly one element");

    auto resultTy = cast<BaseTensorType>(op.getType());
    if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
      return rewriter.notifyMatchFailure(op, "dynamic output shape");

    auto loc = op.getLoc();
    SmallVector<Value> sizes;
    for (auto size : resultTy.getSizes())
      sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(size)));

    Value one = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getType<Torch::IntType>(), 1);
    Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
        loc,
        rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
        one);
    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
    rewriter.replaceOpWithNewOp<AtenFullOp>(
        op, resultTy, sizeList, elements.front(), none, none, none, cstFalse);
    return success();
  }
};
} // namespace

namespace {
template <typename SqueezeOp>
class FoldAtenSqueezePattern : public OpRewritePattern<SqueezeOp> {
public:
  using OpRewritePattern<SqueezeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(SqueezeOp op,
                                PatternRewriter &rewriter) const override {
    auto resultTy = cast<ValueTensorType>(op.getType());
    if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
      return rewriter.notifyMatchFailure(op, "Unknown result shape");

    Value self = op.getSelf();
    if (auto atenFull = self.getDefiningOp<AtenFullOp>()) {
      // in the rank 0 case, just return the rank 0 scalar
      if (resultTy.getSizes().size() == 0) {
        rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
            op, resultTy, atenFull.getFillValue());
        return success();
      }

      SmallVector<Value> sizes;
      for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i)
        sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
            op.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getI64IntegerAttr(i)));

      Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
          op.getLoc(),
          rewriter.getType<Torch::ListType>(
              rewriter.getType<Torch::IntType>()),
          sizes);

      Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
      rewriter.replaceOpWithNewOp<AtenFullOp>(op, resultTy, sizeList,
                                              atenFull.getFillValue(), none,
                                              none, none, none);
      return success();
    }

    return failure();
  }
};
} // namespace

namespace {
class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
public:
  using OpRewritePattern<AtenWhereSelfOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenWhereSelfOp op,
                                PatternRewriter &rewriter) const override {

    auto getRoot = [](Value v) {
      while (true) {
        if (auto numToTensor =
                v.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
          v = numToTensor.getA();
          continue;
        }

        break;
      }

      return v;
    };

    auto self = getRoot(op.getSelf());
    auto other = getRoot(op.getOther());

    if (self == other) {
      rewriter.replaceOp(op, op.getSelf());
      return success();
    }

    auto selfSize = self.getDefiningOp<AtenSizeIntOp>();
    auto otherSize = other.getDefiningOp<AtenSizeIntOp>();

    if (selfSize && otherSize) {
      if (selfSize.getSelf() != otherSize.getSelf())
        return rewriter.notifyMatchFailure(op, "sizes not of same tensor");
      int64_t dimSelf, dimOther;
      if ((selfSize.getDim() != otherSize.getDim()) &&
          (!matchPattern(selfSize.getDim(), m_TorchConstantInt(&dimSelf)) ||
           !matchPattern(otherSize.getDim(), m_TorchConstantInt(&dimOther)) ||
           (dimSelf != dimOther)))
        return rewriter.notifyMatchFailure(op, "sizes not of same dim");

      rewriter.replaceOp(op, op.getSelf());
      return success();
    }

    return rewriter.notifyMatchFailure(op, "unable to fold");
  }
};
} // namespace

namespace {
class FoldAtenUnsqueezePattern : public OpRewritePattern<AtenUnsqueezeOp> {
public:
  using OpRewritePattern<AtenUnsqueezeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenUnsqueezeOp op,
                                PatternRewriter &rewriter) const override {
    auto resultTy = cast<ValueTensorType>(op.getType());
    if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
      return rewriter.notifyMatchFailure(op, "Unknown result shape");

    if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
      SmallVector<Value> sizes;
      for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i)
        sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
            op.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getI64IntegerAttr(i)));

      Value sizeList = rewriter.create<Torch::PrimListConstructOp>(
          op.getLoc(),
          rewriter.getType<Torch::ListType>(
              rewriter.getType<Torch::IntType>()),
          sizes);

      Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
      rewriter.replaceOpWithNewOp<AtenFullOp>(op, resultTy, sizeList,
                                              atenFull.getFillValue(), none,
                                              none, none, none);
      return success();
    }

    auto squeezeOp = op.getSelf().getDefiningOp<AtenSqueezeDimOp>();
    if (squeezeOp && resultTy.getSizes().size() == 1) {
      rewriter.replaceOp(op, squeezeOp.getSelf());
      return success();
    }

    return failure();
  }
};
} // namespace

/// ------ Canonicalization Patterns ------ ///

namespace {
// This is a specific pattern for converting views like [?,...,?,lastDim] ->
// [?,...,?,factor0,factor1] to unflatten, and views like
// [?,...,?,factor0,factor1] -> [?,...,?,lastDim] to flatten, whenever it is
// possible to infer that all but the last shared dims match.
// TODO: move this to an actual canonicalizer for view after deleting the
// conflicting decompositions for flatten/unflatten -> view.
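// Sketch of the intended rewrites (shapes below are illustrative only):
//   view: [?, 4, 8] -> [?, 32]   becomes aten.flatten.using_ints over the
//                                trailing dims (flatten case, inRank > outRank)
//   view: [?, 32]   -> [?, 4, 8] becomes aten.unflatten.int on the last dim
//                                (unflatten case, inRank < outRank)
// provided every leading dim can be matched, either by equal static sizes or
// by an aten.size.int taken from the view's own input; otherwise the view is
// left alone.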
class CanonicalizeAtenViewPattern : public OpRewritePattern<AtenViewOp> {
public:
  using OpRewritePattern<AtenViewOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenViewOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> viewSizes;
    if (failed(getListOperands(op.getSize(), viewSizes)))
      return rewriter.notifyMatchFailure(
          op, "view size must be from a list construct");
    auto selfTy = dyn_cast<Torch::ValueTensorType>(op.getSelf().getType());
    if (!selfTy || !selfTy.hasSizes())
      return rewriter.notifyMatchFailure(op, "missing input type or sizes");
    auto resultTy = dyn_cast<Torch::ValueTensorType>(op.getType());
    if (!resultTy || !resultTy.hasSizes() ||
        resultTy.getSizes().size() != viewSizes.size())
      return rewriter.notifyMatchFailure(op, "missing result type or sizes");
    int64_t inRank = selfTy.getSizes().size();
    int64_t outRank = resultTy.getSizes().size();

    SmallVector<int64_t> sizes(selfTy.getSizes());
    int64_t endMatchingDim = -1;
    // input sizes vs. provided view sizes comparison loop
    for (int64_t i = 0; i < std::min(outRank, inRank); i++) {
      int64_t providedSize;
      bool providedStatic =
          matchPattern(viewSizes[i], m_TorchConstantInt(&providedSize));
      // if sizes[i] is static, it must match a constant in viewSizes[i]
      if (sizes[i] != Torch::kUnknownSize) {
        if (!providedStatic)
          return rewriter.notifyMatchFailure(
              op, "unsupported: found static input dim, but unable to match "
                  "provided view size on a constant. See position : " +
                      std::to_string(i));
        if (providedSize != sizes[i]) {
          endMatchingDim = i;
          break;
        }
        continue;
      }
      // the remaining assumes sizes[i] is dynamic
      // if provided dim is static, we can't verify it is a flatten/unflatten
      // unless -1
      if (i == outRank - 1 && providedStatic && providedSize == -1) {
        endMatchingDim = i;
        break;
      }
      if (providedStatic)
        return rewriter.notifyMatchFailure(
            op, "unexpected static view dim corresponding to dynamic input "
                "dim at position : " +
                    std::to_string(i));
      auto sizeIntOp = viewSizes[i].getDefiningOp<AtenSizeIntOp>();
      // if we don't have a size int op on self, fail
      if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf())
        return rewriter.notifyMatchFailure(
            op, "expected dynamic view dim to come from a corresponding "
                "size.int op. See position : " +
                    std::to_string(i));
      int64_t dim;
      // if the dim of the size int op doesn't match, fail
      if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) ||
          dim != i)
        return rewriter.notifyMatchFailure(
            op,
            "size int op dim cannot be matched to current dim at position : " +
                std::to_string(i));
      // passing the previous checks means viewSizes[i] = aten.size.int(self,
      // i), so continue
    }
    // if all dims match and the ranks are equal, fold
    if (endMatchingDim == -1 && inRank == outRank) {
      rewriter.replaceOp(op, op.getSelf());
      return success();
    }
    if (endMatchingDim > -1 && inRank > outRank) {
      // only support flattening last dim
      if (endMatchingDim != outRank - 1)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: output has more than back dim mismatching");
      // flatten
      Value start =
          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
      Value end =
          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), inRank - 1);
      rewriter.replaceOpWithNewOp<AtenFlattenUsingIntsOp>(
          op, resultTy, op.getSelf(), start, end);
      return success();
    }
    if (endMatchingDim > -1 && inRank < outRank) {
      // only support unflattening last dim
      if (endMatchingDim != inRank - 1)
        return rewriter.notifyMatchFailure(
            op, "unimplemented: input has more than back dim mismatching");
      // unflatten
      Value dim =
          rewriter.create<Torch::ConstantIntOp>(op.getLoc(), endMatchingDim);
      Value primList = rewriter.create<Torch::PrimListConstructOp>(
          op.getLoc(), op.getSize().getType(),
          ArrayRef<Value>(viewSizes.begin() + endMatchingDim,
                          viewSizes.end()));
      rewriter.replaceOpWithNewOp<AtenUnflattenIntOp>(
          op, resultTy, op.getSelf(), dim, primList);
      return success();
    }
    // examples that might reach this:
    // input shape = [10, 5]; view sizes = [5, 10] (or dynamic variants)
    // input shape = [dim0, dim1]; view sizes = [dim0, dim1, 1, 1] (unsqueezes)
    // input shape = [dim0, dim1, 1, 1]; view sizes = [dim0, dim1] (squeezes)
    return rewriter.notifyMatchFailure(
        op,
        "unhandled case: endMatchingDim=" + std::to_string(endMatchingDim) +
            ", inRank=" + std::to_string(inRank) +
            ", outRank=" + std::to_string(outRank));
  }
};
} // namespace

namespace {
template <typename T> class RemoveUnusedPattern : public OpRewritePattern<T> {
public:
  using OpRewritePattern<T>::OpRewritePattern;
  LogicalResult matchAndRewrite(T op,
                                PatternRewriter &rewriter) const override {
    for (auto use : op->getResults())
      if (!use.use_empty())
        return failure();

    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace

namespace {

bool isSourceOpForShapeScalarization(Operation *op) {
  return llvm::isa<AtenSizeIntOp, Torch::ConstantIntOp, Torch::ConstantBoolOp,
                   Aten_ShapeAsTensorOp, Torch::ValueTensorLiteralOp>(op);
}

bool isPrimListOfInts(Operation *op) {
  auto primListOp = dyn_cast<Torch::PrimListConstructOp>(op);
  if (!primListOp)
    return false;
  auto listType = dyn_cast<Torch::ListType>(primListOp.getType());
  if (!listType)
    return false;
  return llvm::isa<Torch::IntType>(listType.getContainedType());
}

bool isAnchorOp(Operation *op) {
  return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
         isPrimListOfInts(op);
}

void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
  patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
                  FoldAtenSqueezePattern<AtenSqueezeDimOp>,
                  FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
                  FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
      patterns.getContext());
}

void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
  patterns.add<CanonicalizeAtenViewPattern>(patterns.getContext());
}

void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
  // A note on division: onnx.Div from int, int -> int types rounds towards
  // zero. The torch DivTensorOp actually doesn't allow returning an int dtype,
  // but this was artificially plumbed through. Unfortunately, there is no
  // scalar trunc div op in torch; however, we can safely assume all operands
  // are positive so floor divide should be a sufficient scalar replacement.
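  // Concretely (sketch): an int-typed aten.div.Tensor over shape scalars x
  // and y is propagated element-wise as aten.floordiv.int(x, y); floor
  // division agrees with round-toward-zero only because shape values are
  // assumed non-negative here.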
  patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
                  PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
                  PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
                  PropagateAtenWhereSelfPattern,
                  PropagateAtenBroadcastToPattern,
                  PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
                  PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
                  PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
                  PropagateAtenArithmeticPattern<AtenDivTensorOp,
                                                 AtenFloordivIntOp>>(
      patterns.getContext());
}

void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
  patterns.insert<RemoveUnusedPattern<Torch::AtenIntBoolOp>,
                  RemoveUnusedPattern<Torch::AtenEqIntOp>,
                  RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
                  RemoveUnusedPattern<Torch::AtenFullOp>,
                  RemoveUnusedPattern<Torch::AtenUnsqueezeOp>,
                  RemoveUnusedPattern<Torch::AtenSqueezeDimOp>,
                  RemoveUnusedPattern<Torch::AtenSizeIntOp>,
                  RemoveUnusedPattern<Torch::AtenSliceTensorOp>,
                  RemoveUnusedPattern<Torch::AtenTensorOp>,
                  RemoveUnusedPattern<Torch::ConstantBoolOp>,
                  RemoveUnusedPattern<Torch::ConstantIntOp>,
                  RemoveUnusedPattern<Torch::ConstantNoneOp>,
                  RemoveUnusedPattern<Torch::PrimListConstructOp>>(
      patterns.getContext());
}

} // namespace

namespace {
class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
public:
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    // populate patterns
    populateScalarizationPropagationPatterns(patterns);
    populateScalarizationFoldPatterns(patterns);
    populateScalarizationCanonicalizePatterns(patterns);
    populateScalarizationRemovePatterns(patterns);
    context->getLoadedDialect<mlir::arith::ArithDialect>()
        ->getCanonicalizationPatterns(patterns);
    // don't load torch canonicalization patterns, since these may lead to
    // issues with propagation

    // walk func op bottom-up to collect a SetVector of shape-related
    // operations. When we pass this SetVector to the pattern rewrite driver,
    // it will process the operations top-down, thereby propagating
    // scalarization starting from sources.
    auto funcOp = getOperation();
    llvm::SetVector<Operation *> shapeCalculationOps;
    funcOp.walk<WalkOrder::PreOrder, mlir::ReverseIterator>(
        [&](Operation *op) {
          // Walking bottom-up, start adding ops when we reach an anchor point
          // (a prim list of ints)
          if (isAnchorOp(op)) {
            shapeCalculationOps.insert(op);
            return;
          }
          // add view ops for now until the decompositions for flatten and
          // unflatten are removed.
          if (isa<AtenViewOp>(op)) {
            shapeCalculationOps.insert(op);
            return;
          }
          // Insert the op if any of its consumers have already been
          // identified as a shape calculation op. To avoid adding the
          // producer of something like a size.int op, don't add ops when
          // their consumer is a source op for shape scalarization. Here is
          // some sample IR:
          // ------
          // %0 = aten.matmul %arg0, %arg1 : ... -> !torch.vtensor<[?,?,?],f32>
          // %1 = aten.size.int %0, %int0 : !torch.int
          // %2 = prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>
          // return %2 : !torch.list<int>
          // ------
          // In this example, don't add the matmul (%0), or its producers, to
          // shapeCalculationOps. Its consumer (%1) is indeed a shape
          // calculation op, but the size.int op is an elementary unit of
          // shape computation. No further gathering of producers is necessary
          // to reduce this. Similarly, don't add the `self` of a view op.
          for (OpOperand &use : op->getUses()) {
            Operation *userOp = use.getOwner();
            if (shapeCalculationOps.contains(userOp) &&
                !isSourceOpForShapeScalarization(userOp) &&
                !isa<AtenViewOp>(userOp)) {
              shapeCalculationOps.insert(op);
              return;
            }
          }
        });

    GreedyRewriteConfig config;
    // When propagating, we need to go back and clean up aten.Tensor ops that
    // have been further propagated. It is also necessary to add newly created
    // ops for custom folding after scalarizing a where.self op.
    config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
    if (failed(applyOpPatternsAndFold(shapeCalculationOps.getArrayRef(),
                                      std::move(patterns), config))) {
      return signalPassFailure();
    }

    // TODO: Warn when failing to process operations in the worklist.
  }
};
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createScalarizeShapesPass() {
  return std::make_unique<ScalarizeShapesPass>();
}