//===- ReduceOpVariants.cpp --------------------------------------*- C++-*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Transforms/DialectConversion.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/Torch/Transforms/Passes.h" #include "llvm/ADT/StringExtras.h" using namespace mlir; using namespace mlir::NPCOMP; using namespace mlir::NPCOMP::Torch; namespace { // Convert value semantic ops operating on mutable arrays to instead operate on // immutable tensors. class ConvertToImmutableTensors : public RewritePattern { public: ConvertToImmutableTensors(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (!op->hasTrait()) return rewriter.notifyMatchFailure(op, "does not have value semantics"); rewriter.updateRootInPlace(op, [&]() { // Convert all operands. SmallVector newOperands; for (OpOperand &opOperand : op->getOpOperands()) { auto tensorType = opOperand.get().getType().dyn_cast(); if (!tensorType) continue; opOperand.set(rewriter.create( op->getLoc(), tensorType.getWithValueSemantics(), opOperand.get())); } // Convert all results. rewriter.setInsertionPointAfter(op); for (Value result : op->getResults()) { auto tensorType = result.getType().dyn_cast(); if (!tensorType) continue; auto createArray = rewriter.create( op->getLoc(), result.getType(), result); result.replaceAllUsesExcept(createArray, createArray); result.setType(tensorType.getWithValueSemantics()); } }); return success(); } }; } // namespace namespace { // Reduce the "trailing underscore inplace variant" to the value semantic // variant + an overwrite of the original "self" argument. class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern { public: ReduceTrailingUnderscoreInplaceVariant(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (!op->hasTrait()) return rewriter.notifyMatchFailure(op, "is not trailing_ variant"); SmallVector fragments; llvm::SplitString(op->getName().getStringRef(), fragments, "."); assert(fragments.size() >= 3 && fragments[2].endswith("_") && "IsTrailingUnderscoreInplaceVariant incorrectly applied"); fragments[2] = fragments[2].drop_back(); std::string noUnderscoreName = llvm::join(fragments, "."); OperationState state(op->getLoc(), noUnderscoreName); state.addTypes(op->getResultTypes()); state.addOperands(op->getOperands()); state.addAttributes(op->getAttrDictionary().getValue()); // Note: No successors or regions. Torch JIT operators don't have any. assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 && "Torch JIT operators shouldn't have regions or successors"); Operation *newOp = rewriter.createOperation(state); auto tensor = rewriter.create(op->getLoc(), newOp->getResult(0) .getType() .cast() .getWithValueSemantics(), newOp->getResult(0)); rewriter.create(op->getLoc(), tensor, op->getOperand(0)); rewriter.replaceOp(op, op->getOperand(0)); return success(); } }; } // namespace namespace { class ReduceOpVariantsPass : public ReduceOpVariantsBase { void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); ConversionTarget target(*context); target.markUnknownOpDynamicallyLegal([](Operation *op) { if (op->hasTrait()) { auto hasValueSemantics = [](Type t) { // TODO: Make this an allowlist based on a closed torch dialect // type system. if (auto tensorType = t.dyn_cast()) { return false; } return true; }; return llvm::all_of(op->getOperandTypes(), hasValueSemantics) && llvm::all_of(op->getResultTypes(), hasValueSemantics); } if (op->hasTrait()) { return false; } return true; }); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::NPCOMP::Torch::createReduceOpVariantsPass() { return std::make_unique(); }