//===- MaximizeValueSemantics.cpp --------------------------------*- C++-*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "npcomp/Dialect/Torch/IR/TorchOps.h" #include "npcomp/Dialect/Torch/Transforms/Passes.h" using namespace mlir; using namespace mlir::NPCOMP; using namespace mlir::NPCOMP::Torch; namespace { class MaximizeValueSemanticsPass : public MaximizeValueSemanticsBase { void runOnOperation() override { MLIRContext *context = &getContext(); auto func = getOperation(); RewritePatternSet patterns(context); CopyTensorOp::getCanonicalizationPatterns(patterns, context); TensorStaticInfoCastOp::getCanonicalizationPatterns(patterns, context); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } }; } // namespace std::unique_ptr> mlir::NPCOMP::Torch::createMaximizeValueSemanticsPass() { return std::make_unique(); }