Add a new RecomposeComplexOps pass, fold slice+copy_ into indeX_put_ (#1901)

pull/1931/head oneshot-20230310.105
gpetters94 2023-03-10 16:42:11 -05:00 committed by GitHub
parent 2be48c3a67
commit 66b1045a80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 153 additions and 0 deletions

View File

@ -850,6 +850,8 @@ LTC_XFAIL_SET = {
"DropoutTrainModule_basic",
"StdCorrectionKeepDimModule_basic",
"StdCorrectionNoneModule_basic",
"SliceCopy_Module_basic",
"SliceCopyNegative_Module_basic",
"VarBiasedModule_basic",
"VarCorrectionAllDimReduceModule_basic",
"VarCorrectionEmptyDimModule_basic",

View File

@ -98,6 +98,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();

View File

@ -9,6 +9,7 @@ add_mlir_library(TorchMLIRTorchPasses
LowerToBackendContract.cpp
MaximizeValueSemantics.cpp
PrepareForGlobalizeObjectGraph.cpp
RecomposeComplexOps.cpp
ReduceOpVariants.cpp
RefinePublicReturn.cpp
RefineTypes.cpp

View File

@ -106,6 +106,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
// Clean up again to avoid needing to to back around the fixed-point
// iteration.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
// Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

View File

@ -0,0 +1,103 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCopy_Op op,
PatternRewriter &rewriter) const override {
if (!op.getSelf().getDefiningOp() ||
!isa<AtenSliceTensorOp>(op.getSelf().getDefiningOp()))
return failure();
auto sliceOp = cast<AtenSliceTensorOp>(op.getSelf().getDefiningOp());
// Get indices
int64_t dim;
if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
return failure();
int64_t end;
if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
return failure();
Value newEnd = sliceOp.getEnd();
if (end < 0) {
Value dimSize = rewriter.create<AtenSizeIntOp>(
op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
newEnd =
rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
}
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
// Create IndexPut_Op
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
Value range = rewriter.create<AtenArangeStartStepOp>(
op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
/*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
/*pin_memory=*/noneVal);
SmallVector<Value> indicesVector;
for (auto i = 0; i < dim - 1; i++)
indicesVector.push_back(noneVal);
indicesVector.push_back(range);
Value indices = rewriter.create<PrimListConstructOp>(
op.getLoc(),
Torch::ListType::get(op->getContext(),
Torch::OptionalType::get(tensorType)),
indicesVector);
rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(),
/*accumulate=*/falseVal, /*unsafe=*/falseVal);
return success();
}
};
} // namespace
namespace {
class RecomposeComplexOps
: public DecomposeComplexOpsBase<RecomposeComplexOps> {
public:
RecomposeComplexOps() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
// pattern.add calls go here
patterns.add<RecomposeSliceCopy_>(context);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoLimit;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRecomposeComplexOps() {
return std::make_unique<RecomposeComplexOps>();
}

View File

@ -481,3 +481,47 @@ class NarrowVerticalTest2(torch.nn.Module):
@register_test_case(module_factory=lambda: NarrowVerticalTest2())
def NarrowVerticalTest2_basic(module, tu: TestUtils):
module.forward(tu.rand(6,4))
# ==============================================================================
class SliceCopy_Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([10, 4, 4], torch.float32, True),
([4, 4, 4], torch.float32, True),
])
def forward(self, x, y):
xslice = torch.ops.aten.slice(x, 0, 2, 6, 1)
xslice.copy_(y)
return x
@register_test_case(module_factory=lambda: SliceCopy_Module())
def SliceCopy_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))
# ==============================================================================
class SliceCopyNegative_Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, x, y):
xslice = torch.ops.aten.slice(x, 0, 2, -4, 1)
xslice.copy_(y)
return x
@register_test_case(module_factory=lambda: SliceCopyNegative_Module())
def SliceCopyNegative_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 4, 4), tu.rand(4, 4, 4))