mirror of https://github.com/llvm/torch-mlir
125 lines
5.0 KiB
C++
125 lines
5.0 KiB
C++
//===------------------------------------------------------------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
|
|
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"
|
|
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h"
|
|
#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::torch::TMTensor;
|
|
|
|
/// Recursive method that lowers one dimension of the `ScalarLoopOpInterface` to
|
|
/// scalar loops at a time.
|
|
static LogicalResult lowerToLoopsImpl(OpBuilder &builder,
|
|
ScalarLoopOpInterface scalarLoopOp,
|
|
ArrayRef<Range> loopRanges,
|
|
unsigned loopDepth,
|
|
SmallVectorImpl<Value> &ivs) {
|
|
Location loc = scalarLoopOp.getLoc();
|
|
if (loopDepth == loopRanges.size()) {
|
|
return scalarLoopOp.generateScalarImplementation(builder, loc, ivs);
|
|
}
|
|
LogicalResult status = success();
|
|
Value offset = getValueOrCreateConstantIndexOp(builder, loc,
|
|
loopRanges[loopDepth].offset);
|
|
Value size =
|
|
getValueOrCreateConstantIndexOp(builder, loc, loopRanges[loopDepth].size);
|
|
Value stride = getValueOrCreateConstantIndexOp(builder, loc,
|
|
loopRanges[loopDepth].stride);
|
|
builder.create<scf::ForOp>(
|
|
loc, offset, size, stride, ValueRange{},
|
|
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
|
|
ivs.push_back(iv);
|
|
status =
|
|
lowerToLoopsImpl(b, scalarLoopOp, loopRanges, loopDepth + 1, ivs);
|
|
b.create<scf::YieldOp>(loc);
|
|
});
|
|
return status;
|
|
}
|
|
|
|
/// Main entry point for lowering `ScalarLoopOpInterface` op to loops.
|
|
static LogicalResult lowerToLoops(OpBuilder &builder,
|
|
ScalarLoopOpInterface scalarLoopOp) {
|
|
SmallVector<Range> loopBounds = scalarLoopOp.getIterationDomain(builder);
|
|
SmallVector<Value> ivs;
|
|
return lowerToLoopsImpl(builder, scalarLoopOp, loopBounds, 0, ivs);
|
|
}
|
|
|
|
/// Pattern rewriter hook to lower a `ScalarLoopOpInterface` to loops.
|
|
namespace {
|
|
struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern {
|
|
ScalarLoopOpInterfaceLowerToLoopsPattern(MLIRContext *context,
|
|
PatternBenefit benefit = 1)
|
|
: RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
|
|
|
|
LogicalResult matchAndRewrite(Operation *op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto scalarLoopOp = dyn_cast<ScalarLoopOpInterface>(op);
|
|
if (!scalarLoopOp) {
|
|
return failure();
|
|
}
|
|
if (llvm::any_of(scalarLoopOp->getResults(),
|
|
[&](Value v) { return v.getType().isa<ShapedType>(); })) {
|
|
return rewriter.notifyMatchFailure(
|
|
scalarLoopOp, "lower to loops needs to have tensor semantics");
|
|
}
|
|
if (failed(lowerToLoops(rewriter, scalarLoopOp))) {
|
|
return failure();
|
|
}
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pass
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<linalg::LinalgDialect, func::FuncDialect,
|
|
mlir::arith::ArithDialect, math::MathDialect,
|
|
memref::MemRefDialect, scf::SCFDialect>();
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
MLIRContext *context = &getContext();
|
|
|
|
RewritePatternSet patterns(context);
|
|
patterns.insert<ScalarLoopOpInterfaceLowerToLoopsPattern>(context);
|
|
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
|
std::move(patterns)))) {
|
|
return signalPassFailure();
|
|
}
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
|
torch::TMTensor::createTMTensorToLoopsPass() {
|
|
return std::make_unique<TMTensorToLoopsPass>();
|
|
}
|