//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { LogicalResult materializeFolds(ImplicitLocOpBuilder b, ArrayRef fold, SmallVector &values) { for (auto f : fold) { if (auto val = dyn_cast(f)) { values.push_back(val); continue; } if (auto attr = dyn_cast(f)) { if (auto val = dyn_cast(attr)) { values.push_back(b.create( b.getType(), val)); continue; } if (auto val = dyn_cast(attr)) { values.push_back( b.create(b.getType(), val)); continue; } } return failure(); } return success(); } LogicalResult getListOperands(Value value, SmallVector &vals) { auto list = value.getDefiningOp(); if (!list) return failure(); for (auto operand : list.getOperands()) vals.push_back(operand); return success(); } LogicalResult getListFromTensor(Value value, SmallVector &vals) { constexpr int64_t kMaxFold = 16; if (auto tensor = value.getDefiningOp()) return getListOperands(tensor.getData(), vals); if (auto full = value.getDefiningOp()) { auto ty = cast(full.getType()); if (!ty.areAllSizesKnown() || ty.getSizes().size() != 1) return failure(); if (ty.getSizes()[0] > kMaxFold) return failure(); vals.resize(vals.size() + ty.getSizes()[0], full.getFillValue()); return success(); } return failure(); } } // namespace namespace { class PropagateAtenShapeToTensorPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_ShapeAsTensorOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto self = op.getSelf(); auto selfTy = cast(self.getType()); if (!selfTy.hasSizes()) return rewriter.notifyMatchFailure(op, "self has unknown rank"); int64_t rank = selfTy.getSizes().size(); SmallVector dims; for (int64_t i = 0; i < rank; ++i) { auto iv = rewriter.create( loc, rewriter.getI64IntegerAttr(i)); dims.push_back(rewriter.create( loc, rewriter.getType(), self, iv)); } auto dimList = rewriter.create( loc, rewriter.getType(rewriter.getType()), dims); Value cstNone = rewriter.create(loc); Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( op, op.getType(), dimList, cstNone, cstNone, cstFalse); return success(); } }; } // namespace namespace { class PropagateAtenCatPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenCatOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); constexpr int64_t kMaxFold = 16; auto resultTy = dyn_cast(op.getType()); if (!resultTy.hasSizes() || resultTy.getSizes().size() != 1 || !resultTy.areAllSizesKnown()) return failure(); if (resultTy.getSizes().front() > kMaxFold) return failure(); if (!resultTy.hasDtype()) return failure(); SmallVector tensors; if (failed(getListOperands(op.getTensors(), tensors))) return failure(); SmallVector scalars; for (auto element : tensors) { llvm::SmallVector delisted; if (succeeded(getListFromTensor(element, delisted))) { for (auto scalar : delisted) scalars.push_back(scalar); continue; } DenseElementsAttr attr; if (matchPattern(element, m_Constant(&attr))) { if (attr.isSplat()) { scalars.resize(scalars.size() + attr.getNumElements(), attr.getSplatValue()); continue; } for (auto e : attr.getValues()) { scalars.push_back(e); } continue; } return rewriter.notifyMatchFailure(op, "unknown op fold type"); } for (auto &scalar : scalars) { if (auto attr = dyn_cast(scalar)) { if (auto iattr = dyn_cast(attr)) { auto i64 = iattr.getValue().getSExtValue(); scalar = rewriter.getI64IntegerAttr(i64); } } } SmallVector values; if (failed(materializeFolds(b, scalars, values))) return rewriter.notifyMatchFailure(op, "unable to materialize constants"); Type eTy = b.getType(); if (isa(resultTy.getDtype())) eTy = rewriter.getType(); auto elementsList = b.create( rewriter.getType(eTy), values); Value cstNone = b.create(); Value cstFalse = b.create(rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( op, op.getType(), elementsList, cstNone, cstNone, cstFalse); return success(); } }; } // namespace namespace { class PropagateAtenIndexSelectPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenIndexSelectOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); SmallVector elements; if (failed(getListFromTensor(op.getSelf(), elements))) return failure(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "requires a constant dim"); DenseElementsAttr idx; if (!matchPattern(op.getIndex(), m_Constant(&idx))) return rewriter.notifyMatchFailure(op, "requires a constant index"); auto selfTy = cast(op.getSelf().getType()); if (!selfTy.hasSizes()) return rewriter.notifyMatchFailure(op, "requires known rank"); auto selfShape = selfTy.getSizes(); int64_t selfRank = selfShape.size(); dim = dim < 0 ? dim + selfRank : dim; int64_t dimLength = elements.size(); if (selfShape[dim] != dimLength) return rewriter.notifyMatchFailure( op, "dim length does not match number of elements"); for (int64_t i = 0; i < selfRank; ++i) { if (i == dim) continue; if (selfShape[i] != 1) return rewriter.notifyMatchFailure(op, "expects unary non-dim dimension"); } SmallVector selected; if (idx.isSplat()) { int64_t indexInt = idx.getSplatValue().getSExtValue(); indexInt = indexInt < 0 ? indexInt + dimLength : indexInt; selected.resize(idx.getNumElements(), elements[indexInt]); } else { for (APInt val : idx.getValues()) { int64_t indexInt = val.getSExtValue(); selected.push_back(elements[indexInt]); } } auto eTy = elements.front().getType(); auto dimList = rewriter.create( loc, rewriter.getType(eTy), selected); Value cstNone = rewriter.create(loc); Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( op, op.getType(), dimList, cstNone, cstNone, cstFalse); return success(); } }; } // namespace namespace { // Conversion attempts to handle some common propagatable slice cases, namely // splatted values, no-op slices, known list of values, or any case where a // new construction can be generated from a previous set of scalars allowing // the parent tensor to be bypassed. class PropagateAtenSliceTensorPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSliceTensorOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); SmallVector elements; if (failed(getListFromTensor(op.getSelf(), elements))) return failure(); int64_t dim, start, end, step; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "requires a constant dim"); if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) return rewriter.notifyMatchFailure(op, "requires a constant start"); if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) return rewriter.notifyMatchFailure(op, "requires a constant end"); if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) return rewriter.notifyMatchFailure(op, "requires a constant step"); if (step < 0) return rewriter.notifyMatchFailure(op, "requires a positive step value"); auto selfTy = cast(op.getSelf().getType()); auto selfShape = selfTy.getSizes(); int64_t selfRank = selfShape.size(); // Correct for negative indexing: dim = dim < 0 ? dim + selfRank : dim; int64_t dimLength = elements.size(); start = start < 0 ? start + dimLength : start; end = end < 0 ? end + dimLength : end; start = start < 0 ? 0 : start; end = end < 0 ? 0 : end; end = end > dimLength ? dimLength : end; if (selfShape[dim] != dimLength) return rewriter.notifyMatchFailure( op, "dim length does not match number of elements"); for (int64_t i = 0; i < selfRank; ++i) { if (i == dim) continue; if (selfShape[i] != 1) return rewriter.notifyMatchFailure(op, "expects unary non-dim dimension"); } SmallVector selected; for (int i = start; i < end; i += step) selected.push_back(elements[i]); auto eTy = elements.front().getType(); auto dimList = rewriter.create( loc, rewriter.getType(eTy), selected); Value cstNone = rewriter.create(loc); Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( op, op.getType(), dimList, cstNone, cstNone, cstFalse); return success(); } }; } // namespace namespace { class PropagateAtenItemPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenItemOp op, PatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); SmallVector elements; if (failed(getListFromTensor(op.getSelf(), elements))) return failure(); if (elements.size() != 1) return rewriter.notifyMatchFailure(op, "expected no elements"); rewriter.replaceOp(op, elements[0]); return success(); } }; } // namespace namespace { class FoldAtenTensorSplatPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenTensorOp op, PatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); SmallVector elements; if (failed(getListOperands(op.getData(), elements))) return failure(); if (elements.size() < 1) return rewriter.notifyMatchFailure(op, "no elements"); auto front = elements.front(); for (auto element : elements) if (element != front) return rewriter.notifyMatchFailure(op, "multiple elements found"); if (elements.size() != 1) return rewriter.notifyMatchFailure(op, "expected no elements"); auto resultTy = cast(op.getType()); if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown()) return rewriter.notifyMatchFailure(op, "dynamic output shape"); auto loc = op.getLoc(); SmallVector sizes; for (auto size : resultTy.getSizes()) sizes.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(size))); Value one = rewriter.create( loc, rewriter.getType(), 1); Value sizeList = rewriter.create( loc, rewriter.getType(rewriter.getType()), one); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); rewriter.replaceOpWithNewOp( op, resultTy, sizeList, elements.front(), none, none, none, cstFalse); return success(); } }; } // namespace namespace { class FoldAtenSqueezePattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSqueezeOp op, PatternRewriter &rewriter) const override { auto resultTy = cast(op.getType()); if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown()) return rewriter.notifyMatchFailure(op, "Unknown result shape"); if (auto atenFull = op.getSelf().getDefiningOp()) { SmallVector sizes; for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i) sizes.push_back(rewriter.create( op.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(i))); Value sizeList = rewriter.create( op.getLoc(), rewriter.getType(rewriter.getType()), sizes); Value none = rewriter.create(op.getLoc()); rewriter.replaceOpWithNewOp(op, resultTy, sizeList, atenFull.getFillValue(), none, none, none, none); return success(); } return failure(); } }; } // namespace namespace { class FoldAtenWhereSelf : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenWhereSelfOp op, PatternRewriter &rewriter) const override { auto getRoot = [](Value v) { while (true) { if (auto numToTensor = v.getDefiningOp()) { v = numToTensor.getA(); continue; } break; } return v; }; auto self = getRoot(op.getSelf()); auto other = getRoot(op.getOther()); if (self == other) { rewriter.replaceOp(op, op.getSelf()); return success(); } auto selfSize = self.getDefiningOp(); auto otherSize = other.getDefiningOp(); if (selfSize && otherSize) { if (selfSize.getSelf() != otherSize.getSelf()) return failure(); if (selfSize.getDim() != otherSize.getDim()) return failure(); rewriter.replaceOp(op, op.getSelf()); return success(); } return failure(); } }; } // namespace namespace { class FoldAtenUnsqueezePattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenUnsqueezeOp op, PatternRewriter &rewriter) const override { auto resultTy = cast(op.getType()); if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown()) return rewriter.notifyMatchFailure(op, "Unknown result shape"); if (auto atenFull = op.getSelf().getDefiningOp()) { SmallVector sizes; for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i) sizes.push_back(rewriter.create( op.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(i))); Value sizeList = rewriter.create( op.getLoc(), rewriter.getType(rewriter.getType()), sizes); Value none = rewriter.create(op.getLoc()); rewriter.replaceOpWithNewOp(op, resultTy, sizeList, atenFull.getFillValue(), none, none, none, none); return success(); } return failure(); } }; } // namespace namespace { template class RemoveUnusedPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { for (auto use : op->getResults()) if (!use.use_empty()) return failure(); rewriter.eraseOp(op); return success(); } }; } // namespace namespace { class ScalarizeShapesPass : public ScalarizeShapesBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); patterns .insert, RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern>(context); context->getLoadedDialect() ->getCanonicalizationPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::torch::Torch::createScalarizeShapesPass() { return std::make_unique(); }