//===- Bufferize.cpp - Bufferization of tmtensor ops ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/PassDetail.h" #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" using namespace ::mlir; using namespace ::mlir::torch::TMTensor; static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { auto memrefType = memref.getType().cast(); auto alloc = b.create( loc, memref::getMixedSizes(b, loc, memref), memrefType.getElementType()); b.create(loc, memref, alloc); return alloc; } static LogicalResult allocateBuffersForResults(Location loc, TMTensorOp tmtensorOp, ValueRange outputs, SmallVectorImpl &resultBuffers, OpBuilder &b) { // Lazily compute loopRanges. SmallVector loopRanges; // Allocate a buffer for every tensor result. assert(tmtensorOp.getNumOutputs() == tmtensorOp->getNumResults()); for (const auto &en : llvm::enumerate(tmtensorOp->getResultTypes())) { size_t resultIndex = en.index(); Type resultType = en.value(); auto tensorType = resultType.dyn_cast(); if (tensorType == nullptr) { tmtensorOp.emitOpError() << "tensor to buffer conversion expects ranked tensor results"; return failure(); } auto tensorShape = tensorType.getShape(); auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType()); Value resultTensor = outputs[resultIndex]; // Clone output buffers whose value is actually used. OpOperand *tiedOpOperand = tmtensorOp.getOutputOperand(resultIndex); if (tmtensorOp.payloadUsesValueFromOperand(tiedOpOperand)) { resultBuffers.push_back(cloneMemref(loc, resultTensor, b)); continue; } // Allocate buffers for statically-shaped results. if (memrefType.hasStaticShape()) { resultBuffers.push_back(b.create(loc, memrefType)); continue; } resultBuffers.push_back(b.create( loc, memref::getMixedSizes(b, loc, resultTensor), memrefType.getElementType())); } return success(); } /// Create TMTensor op on buffers given the original tensor-based operation and /// the buffers for the outputs. static TMTensorOp createTMTensorOpOnBuffers(ConversionPatternRewriter &rewriter, TMTensorOp tmtensorOp, ValueRange inputs, ValueRange outputs) { SmallVector newOperands = inputs; newOperands.append(outputs.begin(), outputs.end()); return cast( tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands)); } /// Generic conversion pattern that matches any TMTensorOp. This avoids template /// instantiating one pattern for each TMTensorOp. class BufferizeAnyTMTensorOp : public OpInterfaceConversionPattern { public: using OpInterfaceConversionPattern::OpInterfaceConversionPattern; LogicalResult matchAndRewrite(TMTensorOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); SmallVector newOutputBuffers; SmallVector outputs(operands.begin() + op.getNumInputs(), operands.end()); if (failed(allocateBuffersForResults(loc, op, outputs, newOutputBuffers, rewriter))) { return op.emitOpError() << "Failed to allocate buffers for tensor results."; } SmallVector inputs(operands.begin(), operands.begin() + op.getNumInputs()); createTMTensorOpOnBuffers(rewriter, op, inputs, newOutputBuffers); // Replace the results of the old op with the new output buffers. rewriter.replaceOp(op, newOutputBuffers); return success(); } }; namespace { /// Converts TMTensor operations that work on tensor-type operands or results to /// work on buffers. struct TMTensorBufferizePass : public TMTensorBufferizeBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { MLIRContext &context = getContext(); ConversionTarget target(context); bufferization::BufferizeTypeConverter typeConverter; // Mark all Standard operations legal. target.addLegalDialect(); // Mark all TMTensor operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) { return typeConverter.isLegal(op); }; target.addDynamicallyLegalDialect(isLegalOperation); RewritePatternSet patterns(&context); patterns.add(typeConverter, patterns.getContext()); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace std::unique_ptr> torch::TMTensor::createTMTensorBufferizePass() { return std::make_unique(); }