mirror of https://github.com/llvm/torch-mlir
Add a new RecomposeComplexOps pass, fold slice+copy_ into index_put_ (#1901)
parent
2be48c3a67
commit
66b1045a80
|
@ -850,6 +850,8 @@ LTC_XFAIL_SET = {
|
|||
"DropoutTrainModule_basic",
|
||||
"StdCorrectionKeepDimModule_basic",
|
||||
"StdCorrectionNoneModule_basic",
|
||||
"SliceCopy_Module_basic",
|
||||
"SliceCopyNegative_Module_basic",
|
||||
"VarBiasedModule_basic",
|
||||
"VarCorrectionAllDimReduceModule_basic",
|
||||
"VarCorrectionEmptyDimModule_basic",
|
||||
|
|
|
@ -98,6 +98,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
|
|||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
|
||||
|
|
|
@ -9,6 +9,7 @@ add_mlir_library(TorchMLIRTorchPasses
|
|||
LowerToBackendContract.cpp
|
||||
MaximizeValueSemantics.cpp
|
||||
PrepareForGlobalizeObjectGraph.cpp
|
||||
RecomposeComplexOps.cpp
|
||||
ReduceOpVariants.cpp
|
||||
RefinePublicReturn.cpp
|
||||
RefineTypes.cpp
|
||||
|
|
|
@ -106,6 +106,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
|
|||
// Clean up again to avoid needing to go back around the fixed-point
|
||||
// iteration.
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
|
||||
// Reduce variants of ops to a smaller set of primitives.
|
||||
pm.addNestedPass<func::FuncOp>(createReduceOpVariantsPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
// Also available under a BSD-style license. See LICENSE.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
namespace {
|
||||
class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenCopy_Op op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!op.getSelf().getDefiningOp() ||
|
||||
!isa<AtenSliceTensorOp>(op.getSelf().getDefiningOp()))
|
||||
return failure();
|
||||
auto sliceOp = cast<AtenSliceTensorOp>(op.getSelf().getDefiningOp());
|
||||
|
||||
// Get indices
|
||||
int64_t dim;
|
||||
if (!matchPattern(sliceOp.getDim(), m_TorchConstantInt(&dim)))
|
||||
return failure();
|
||||
int64_t end;
|
||||
if (!matchPattern(sliceOp.getEnd(), m_TorchConstantInt(&end)))
|
||||
return failure();
|
||||
|
||||
Value newEnd = sliceOp.getEnd();
|
||||
if (end < 0) {
|
||||
Value dimSize = rewriter.create<AtenSizeIntOp>(
|
||||
op.getLoc(), sliceOp.getSelf(), sliceOp.getDim());
|
||||
newEnd =
|
||||
rewriter.create<AtenAddIntOp>(op.getLoc(), dimSize, sliceOp.getEnd());
|
||||
}
|
||||
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
|
||||
|
||||
// Create IndexPut_Op
|
||||
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
|
||||
Value range = rewriter.create<AtenArangeStartStepOp>(
|
||||
op.getLoc(), tensorType, sliceOp.getStart(), newEnd, sliceOp.getStep(),
|
||||
/*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal,
|
||||
/*pin_memory=*/noneVal);
|
||||
|
||||
SmallVector<Value> indicesVector;
|
||||
for (auto i = 0; i < dim - 1; i++)
|
||||
indicesVector.push_back(noneVal);
|
||||
indicesVector.push_back(range);
|
||||
Value indices = rewriter.create<PrimListConstructOp>(
|
||||
op.getLoc(),
|
||||
Torch::ListType::get(op->getContext(),
|
||||
Torch::OptionalType::get(tensorType)),
|
||||
indicesVector);
|
||||
|
||||
rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_Op>(
|
||||
op, op->getResultTypes(), sliceOp.getSelf(), indices, op.getSrc(),
|
||||
/*accumulate=*/falseVal, /*unsafe=*/falseVal);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class RecomposeComplexOps
|
||||
: public DecomposeComplexOpsBase<RecomposeComplexOps> {
|
||||
public:
|
||||
RecomposeComplexOps() = default;
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
// pattern.add calls go here
|
||||
patterns.add<RecomposeSliceCopy_>(context);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
config.useTopDownTraversal = true;
|
||||
config.maxIterations = GreedyRewriteConfig::kNoLimit;
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
|
||||
config))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Public factory for the RecomposeComplexOps pass (declared in Passes.h);
// the pass runs nested on func.func.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRecomposeComplexOps() {
  auto pass = std::make_unique<RecomposeComplexOps>();
  return pass;
}
|
|
@ -481,3 +481,47 @@ class NarrowVerticalTest2(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: NarrowVerticalTest2())
|
||||
def NarrowVerticalTest2_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6,4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SliceCopy_Module(torch.nn.Module):
    """Slice rows 2..6 of `x` along dim 0 and copy_ `y` into them in place.

    Exercises the recomposition of slice + copy_ into an index_put_.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([10, 4, 4], torch.float32, True),
        ([4, 4, 4], torch.float32, True),
    ])
    def forward(self, x, y):
        # In-place write through a view: the mutation must be observable on x.
        view = torch.ops.aten.slice(x, 0, 2, 6, 1)
        view.copy_(y)
        return x
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceCopy_Module())
def SliceCopy_Module_basic(module, tu: TestUtils):
    # Destination is (10, 4, 4); source matches the 4-row slice exactly.
    dest = tu.rand(10, 4, 4)
    src = tu.rand(4, 4, 4)
    module.forward(dest, src)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SliceCopyNegative_Module(torch.nn.Module):
    """Same as SliceCopy_Module but with a negative slice end (-4) and fully
    dynamic shapes, exercising the end-index normalization path."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x, y):
        # end=-4 forces the recomposition to compute size(x, 0) + end.
        view = torch.ops.aten.slice(x, 0, 2, -4, 1)
        view.copy_(y)
        return x
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceCopyNegative_Module())
def SliceCopyNegative_Module_basic(module, tu: TestUtils):
    # With dim-0 size 10, end=-4 normalizes to 6, so the slice has 4 rows.
    dest = tu.rand(10, 4, 4)
    src = tu.rand(4, 4, 4)
    module.forward(dest, src)
|
||||
|
|
Loading…
Reference in New Issue