//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { LogicalResult getListOperands(Value value, SmallVector &vals) { auto list = value.getDefiningOp(); if (!list) return failure(); for (auto operand : list.getOperands()) vals.push_back(operand); return success(); } LogicalResult getListFromTensor(Value value, SmallVector &vals) { auto tensor = value.getDefiningOp(); if (!tensor) return failure(); return getListOperands(tensor.getData(), vals); } } // namespace namespace { class PropagateAtenShapeToTensorPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten_ShapeAsTensorOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto self = op.getSelf(); auto selfTy = cast(self.getType()); if (!selfTy.hasSizes()) return rewriter.notifyMatchFailure(op, "self has unknown rank"); int64_t rank = selfTy.getSizes().size(); SmallVector dims; for (int64_t i = 0; i < rank; ++i) { auto iv = rewriter.create( loc, rewriter.getI64IntegerAttr(i)); dims.push_back(rewriter.create( loc, rewriter.getType(), self, iv)); } auto dimList = rewriter.create( loc, rewriter.getType(rewriter.getType()), dims); Value cstNone = rewriter.create(loc); Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( op, op.getType(), dimList, cstNone, cstNone, cstFalse); return success(); } }; } // namespace namespace { class PropagateAtenIndexSelectPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenIndexSelectOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); SmallVector elements; if (failed(getListFromTensor(op.getSelf(), elements))) return failure(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "requires a constant dim"); DenseElementsAttr idx; if (!matchPattern(op.getIndex(), m_Constant(&idx))) return rewriter.notifyMatchFailure(op, "requires a constant index"); auto selfTy = cast(op.getSelf().getType()); if (!selfTy.hasSizes()) return rewriter.notifyMatchFailure(op, "requires known rank"); auto selfShape = selfTy.getSizes(); int64_t selfRank = selfShape.size(); dim = dim < 0 ? dim + selfRank : dim; int64_t dimLength = elements.size(); if (selfShape[dim] != dimLength) return rewriter.notifyMatchFailure( op, "dim length does not match number of elements"); for (int64_t i = 0; i < selfRank; ++i) { if (i == dim) continue; if (selfShape[i] != 1) return rewriter.notifyMatchFailure(op, "expects unary non-dim dimension"); } SmallVector selected; if (idx.isSplat()) { int64_t indexInt = idx.getSplatValue().getSExtValue(); indexInt = indexInt < 0 ? indexInt + dimLength : indexInt; selected.resize(idx.getNumElements(), elements[indexInt]); } else { for (APInt val : idx.getValues()) { int64_t indexInt = val.getSExtValue(); selected.push_back(elements[indexInt]); } } auto eTy = elements.front().getType(); auto dimList = rewriter.create( loc, rewriter.getType(eTy), selected); Value cstNone = rewriter.create(loc); Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( op, op.getType(), dimList, cstNone, cstNone, cstFalse); return success(); } }; } // namespace namespace { // Conversion attempts to handle some common propagatable slice cases, namely // splatted values, no-op slices, known list of values, or any case where a // new construction can be generated from a previous set of scalars allowing // the parent tensor to be bypassed. class PropagateAtenSliceTensorPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenSliceTensorOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); SmallVector elements; if (failed(getListFromTensor(op.getSelf(), elements))) return failure(); int64_t dim, start, end, step; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "requires a constant dim"); if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) return rewriter.notifyMatchFailure(op, "requires a constant start"); if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) return rewriter.notifyMatchFailure(op, "requires a constant end"); if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) return rewriter.notifyMatchFailure(op, "requires a constant step"); if (step < 0) return rewriter.notifyMatchFailure(op, "requires a positive step value"); auto selfTy = cast(op.getSelf().getType()); auto selfShape = selfTy.getSizes(); int64_t selfRank = selfShape.size(); // Correct for negative indexing: dim = dim < 0 ? dim + selfRank : dim; int64_t dimLength = elements.size(); start = start < 0 ? start + dimLength : start; end = end < 0 ? end + dimLength : end; start = start < 0 ? 0 : start; end = end < 0 ? 0 : end; end = end > dimLength ? dimLength : end; if (selfShape[dim] != dimLength) return rewriter.notifyMatchFailure( op, "dim length does not match number of elements"); for (int64_t i = 0; i < selfRank; ++i) { if (i == dim) continue; if (selfShape[i] != 1) return rewriter.notifyMatchFailure(op, "expects unary non-dim dimension"); } SmallVector selected; for (int i = start; i < end; i += step) selected.push_back(elements[i]); auto eTy = elements.front().getType(); auto dimList = rewriter.create( loc, rewriter.getType(eTy), selected); Value cstNone = rewriter.create(loc); Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( op, op.getType(), dimList, cstNone, cstNone, cstFalse); return success(); } }; } // namespace namespace { class PropagateAtenItemPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenItemOp op, PatternRewriter &rewriter) const override { SmallVector elements; if (failed(getListFromTensor(op.getSelf(), elements))) return failure(); if (elements.size() != 1) return rewriter.notifyMatchFailure(op, "expected no elements"); rewriter.replaceOp(op, elements[0]); return success(); } }; } // namespace namespace { class FoldAtenTensorSplatPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenTensorOp op, PatternRewriter &rewriter) const override { SmallVector elements; if (failed(getListOperands(op.getData(), elements))) return failure(); if (elements.size() < 1) return rewriter.notifyMatchFailure(op, "no elements"); auto front = elements.front(); for (auto element : elements) if (element != front) return rewriter.notifyMatchFailure(op, "multiple elements found"); if (elements.size() != 1) return rewriter.notifyMatchFailure(op, "expected no elements"); auto resultTy = cast(op.getType()); if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown()) return rewriter.notifyMatchFailure(op, "dynamic output shape"); auto loc = op.getLoc(); llvm::SmallVector sizes; for (auto size : resultTy.getSizes()) sizes.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(size))); Value one = rewriter.create( loc, rewriter.getType(), 1); Value sizeList = rewriter.create( loc, rewriter.getType(rewriter.getType()), one); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); rewriter.replaceOpWithNewOp(op, resultTy, sizeList, front, none, none, none, cstFalse); return success(); } }; } // namespace namespace { template class RemoveUnusedPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { for (auto use : op->getResults()) if (!use.use_empty()) return failure(); rewriter.eraseOp(op); return success(); } }; } // namespace namespace { class ScalarizeShapesPass : public ScalarizeShapesBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); patterns.insert, RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern>(context); context->getLoadedDialect() ->getCanonicalizationPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::torch::Torch::createScalarizeShapesPass() { return std::make_unique(); }