//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SMLoc.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::TMTensor;

//===----------------------------------------------------------------------===//
// Utils.
//===----------------------------------------------------------------------===//

static void getEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ResultRange results, ArrayRef<OpOperand *> inputBuffers,
    ArrayRef<OpOperand *> outputBuffers) {
  for (OpResult value : results) {
    effects.emplace_back(MemoryEffects::Allocate::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (OpOperand *value : inputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (OpOperand *value : outputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), value,
                         SideEffects::DefaultResource::get());
  }
}

Value TMTensor::getDimValue(OpBuilder &builder, Location loc, Value v,
                            int64_t dim) {
  return TypeSwitch<Type, Value>(v.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
        return builder.create<tensor::DimOp>(loc, v, dim);
      })
      .Case<MemRefType>([&](MemRefType t) -> Value {
        return builder.create<memref::DimOp>(loc, v, dim);
      })
      .Default([&](Type t) { return Value(); });
}

OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
                              int64_t dim) {
  auto t = cast<ShapedType>(v.getType());
  if (t.isDynamicDim(dim)) {
    return getDimValue(builder, loc, v, dim);
  }
  return builder.getI64IntegerAttr(t.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//

LogicalResult AttentionOp::verify() {
  Operation *op = getOperation();
  ShapedType queryType = getQueryType();
  ShapedType keyType = getKeyType();
  ArrayRef<int64_t> queryShape = queryType.getShape();
  ArrayRef<int64_t> keyShape = keyType.getShape();
  for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
    if (keyShape[i] != queryShape[i])
      return op->emitOpError("query and key batch mismatch");
  }
  if (keyShape.back() != queryShape.back())
    return op->emitOpError("query and key head dimension mismatch");
  return success();
}
SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
  SmallVector<Range> loopBounds;
  return loopBounds;
}

SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes;
  return iteratorTypes;
}

bool AttentionOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  Value operand = opOperand->get();
  return operand == getQuery() || operand == getKey() || operand == getValue();
}

// Performs a matmul between lhs and rhs.
// Note that "transposed" means the last two dims of rhs are swapped.
static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
                   Value rhs, ValueRange rhsSizes, Value output,
                   ValueRange outputSizes, bool transposed = false) {
  auto elementType = cast<MemRefType>(lhs.getType()).getElementType();
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  auto rank = outputSizes.size();
  Value reductionDimSize = lhsSizes[lhsSizes.size() - 1];

  // Loop over output.
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(rank, zero), outputSizes,
      SmallVector<Value>(rank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value acc = b.create<arith::ConstantOp>(
            loc, elementType, b.getFloatAttr(elementType, 0.0));
        Value sum =
            b.create<scf::ForOp>(
                 loc, zero, reductionDimSize, one, SmallVector<Value>{acc},
                 [&](OpBuilder &b, Location loc, Value i, ValueRange accs) {
                   SmallVector<Value> lhsIVs(localIVs), rhsIVs(localIVs);
                   lhsIVs[lhsIVs.size() - 1] = i;
                   rhsIVs[rhsIVs.size() - 2] = i;
                   if (transposed)
                     std::swap(rhsIVs[rhsIVs.size() - 1],
                               rhsIVs[rhsIVs.size() - 2]);

                   Value acc = accs[0];
                   Value rElem = b.create<memref::LoadOp>(loc, lhs, lhsIVs);
                   Value cElem = b.create<memref::LoadOp>(loc, rhs, rhsIVs);
                   Value x = b.create<arith::MulFOp>(loc, rElem, cElem);
                   x = b.create<arith::AddFOp>(loc, x, acc);

                   b.create<scf::YieldOp>(loc, x);
                 })
                ->getResult(0);

        b.create<memref::StoreOp>(loc, sum, output, localIVs);
      });
}

LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
                                                        Location loc,
                                                        ValueRange ivs) {
  Value query = getQuery();
  Value key = getKey();
  Value value = getValue();
  Value output = getOutput();
  auto queryType = cast<MemRefType>(query.getType());
  auto keyType = cast<MemRefType>(key.getType());
  auto valueType = cast<MemRefType>(value.getType());
  auto queryRank = queryType.getRank();
  auto keyRank = keyType.getRank();
  auto valueRank = valueType.getRank();
  auto keySizes = keyType.getShape();
  Type elementType = queryType.getElementType();

  Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
                                            b.getFloatAttr(elementType, 0.0));

  // TODO: This needs to be fixed, it assumes everything is dynamic however if
  // any shapes are static the `memref.alloc` generated is illegal.
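  // In NumPy-like pseudocode, the loop nests emitted below compute roughly:
  //   weight = query @ key.T                       (matmul, transposed=true)
  //   weight = (weight - max(weight)) / sqrt(querySizes[-1])
  //   weight = exp(weight)
  //   weight = weight / sum(weight, axis=-1)       (softmax over the last dim)
  //   output = weight @ value
  // Note that `max(weight)` here is the global maximum of the whole weight
  // buffer (see `globalMax` below), used only as a numerical-stability shift.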
  SmallVector<Value> queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes;
  for (auto i = 0; i < queryRank; i++)
    queryDynSizes.push_back(b.create<memref::DimOp>(loc, query, i));
  for (auto i = 0; i < keyRank; i++)
    keyDynSizes.push_back(b.create<memref::DimOp>(loc, key, i));
  for (auto i = 0; i < valueRank; i++)
    valueDynSizes.push_back(b.create<memref::DimOp>(loc, value, i));
  for (auto i = 0; i < queryRank; i++)
    outputDynSizes.push_back(b.create<memref::DimOp>(loc, output, i));

  // weight = query @ key
  auto weightRank = queryRank;
  auto weightSizes = SmallVector<int64_t>(queryType.getShape());
  weightSizes[weightRank - 1] = keySizes[keyRank - 2];
  auto weightType = MemRefType::get(weightSizes, queryType.getElementType());

  // Setup the weight dynamic sizes:
  SmallVector<Value> weightDynSizes(queryDynSizes);
  weightDynSizes[weightRank - 1] = keyDynSizes[keyRank - 2];

  SmallVector<Value> weightFilteredDynSizes;
  for (int i = 0; i < weightRank; ++i)
    if (weightSizes[i] == ShapedType::kDynamic)
      weightFilteredDynSizes.push_back(weightDynSizes[i]);

  Value weight =
      b.create<memref::AllocOp>(loc, weightType, weightFilteredDynSizes);
  matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes,
         /*transposed=*/true);

  // weight = softmax(weight)
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value dim = weightDynSizes[weightRank - 1];
  Value scaleFactor = b.create<math::SqrtOp>(
      loc, b.create<arith::UIToFPOp>(
               loc, elementType,
               b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
                                              queryDynSizes[queryRank - 1])));

  // calculate max(weight)
  Value init = b.create<memref::LoadOp>(loc, weight,
                                        SmallVector<Value>(weightRank, zero));
  Value globalMax =
      b.create<scf::ParallelOp>(
           loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
           SmallVector<Value>(weightRank, one), init,
           [&](OpBuilder &b, Location loc, ValueRange localIVs,
               ValueRange accs) {
             auto reduceOp = b.create<scf::ReduceOp>(loc, init);
             // Build reduce body.
             Block &reductionBody = reduceOp.getReductions()[0].front();
             auto bodyBuilder = OpBuilder::atBlockEnd(&reductionBody);
             Value acc = reductionBody.getArgument(0);
             Value x =
                 bodyBuilder.create<memref::LoadOp>(loc, weight, localIVs);
             Value max = bodyBuilder.create<arith::MaximumFOp>(loc, x, acc);
             bodyBuilder.create<scf::ReduceReturnOp>(loc, max);
           })
          .getResult(0);

  // weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
      SmallVector<Value>(weightRank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        x = b.create<arith::SubFOp>(loc, x, globalMax);
        x = b.create<arith::DivFOp>(loc, x, scaleFactor);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
      });

  // calculate exp(weight)
  SmallVector<Value> min(weightRank, zero),
      max(weightDynSizes.begin(), weightDynSizes.end()),
      steps(weightRank, one);
  b.create<scf::ParallelOp>(
      loc, min, max, steps,
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        x = b.create<math::ExpOp>(loc, x);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
      });

  llvm::SmallVector<Value> expWeightDynDims(weightFilteredDynSizes);
  if (weightSizes.back() == ShapedType::kDynamic)
    expWeightDynDims.resize(expWeightDynDims.size() - 1);

  Value expWeightSum = b.create<memref::AllocOp>(
      loc,
      MemRefType::get(
          SmallVector<int64_t>(weightSizes.begin(), weightSizes.end() - 1),
          elementType),
      expWeightDynDims);
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank - 1, zero),
      SmallVector<Value>{weightDynSizes.begin(), weightDynSizes.end() - 1},
      SmallVector<Value>(weightRank - 1, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        b.create<memref::StoreOp>(loc, zeroF, expWeightSum, localIVs);
      });
  // Loop over all dims but -1
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank - 1, zero),
      SmallVector<Value>(weightDynSizes.begin(), weightDynSizes.end() - 1),
      SmallVector<Value>(weightRank - 1, one),
      [&](OpBuilder &b, Location loc, ValueRange outsideDims) {
        // Sum over last dim
        b.create<scf::ParallelOp>(
            loc, zero, dim, one,
            [&](OpBuilder &b, Location loc, ValueRange localIVs) {
              SmallVector<Value> coords(outsideDims);
              coords.push_back(localIVs[0]);
              Value x =
                  b.create<memref::LoadOp>(loc, expWeightSum, outsideDims);
              Value y = b.create<memref::LoadOp>(loc, weight, coords);
              Value sum = b.create<arith::AddFOp>(loc, x, y);
              b.create<memref::StoreOp>(loc, sum, expWeightSum, outsideDims);
            });
      });

  // calculate exp(weight) / sum(exp(weight))
  b.create<scf::ParallelOp>(
      loc, SmallVector<Value>(weightRank, zero),
      SmallVector<Value>(weightDynSizes.begin(), weightDynSizes.end()),
      SmallVector<Value>(weightRank, one),
      [&](OpBuilder &b, Location loc, ValueRange localIVs) {
        SmallVector<Value> sumIVs(localIVs);
        sumIVs.pop_back();
        Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
        Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
        x = b.create<arith::DivFOp>(loc, x, sum);
        b.create<memref::StoreOp>(loc, x, weight, localIVs);
      });

  // output = weight @ value
  matmul(b, loc, weight, weightDynSizes, value, valueDynSizes, output,
         outputDynSizes, /*transposed=*/false);

  return success();
}

//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

LogicalResult ScanOp::verify() {
  if (getNumInputs() != 1) {
    return emitOpError("expected one input operands");
  }
  if (getNumOutputs() != 2) {
    return emitOpError("expected two output operands");
  }
  if (!isa<ShapedType>(input().getType())) {
    return emitOpError("expected first input element type to be shaped");
  }
  auto accumulatorType = cast<ShapedType>(accumulator().getType());
  auto inputType = cast<ShapedType>(input().getType());
  auto outputType = cast<ShapedType>(output().getType());
  ArrayRef<int64_t> inputShapes = inputType.getShape();
  ArrayRef<int64_t> outputShapes = outputType.getShape();
  if (accumulatorType.getElementType() != inputType.getElementType()) {
    return emitOpError(
        "expected input/accumulator element types to be identical");
  }
  ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
  int64_t accumulatorRank = accumulatorType.getRank();
  if (accumulatorRank != inputType.getRank() - 1) {
    return emitOpError(
        "expected accumulator rank to be equal to input rank - 1");
  }
  SmallVector<int64_t> expectedAccumulatorShape;
  for (size_t i = 0; i < (size_t)inputType.getRank(); i++) {
    if (i != getDimension())
      expectedAccumulatorShape.push_back(inputShapes[i]);
  }
  if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != ShapedType::kDynamic &&
                            std::get<1>(s) != ShapedType::kDynamic &&
                            std::get<0>(s) != std::get<1>(s);
                   })) {
    return emitOpError("incompatible input/accumulator shapes");
  }
  if (inputType.getElementType() != outputType.getElementType()) {
    return emitOpError("expected input/output element types to be identical");
  }
  if (inputShapes.size() != outputShapes.size()) {
    return emitOpError("expected input/output to have identical ranks");
  }
  if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != ShapedType::kDynamic &&
                            std::get<1>(s) != ShapedType::kDynamic &&
                            std::get<0>(s) != std::get<1>(s);
                   })) {
    return emitOpError("incompatible input/output shapes");
  }
  return success();
}

SmallVector<Range> ScanOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getOperandRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = input();
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType> ScanOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getOperandRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}

bool ScanOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  Value operand = opOperand->get();
  if (operand == accumulator())
    return !getInclusive();
  else if (operand == output())
    return false;
  else {
    assert(operand == input() &&
           "operand must belong to the current tm_tensor.scan op");
    return true;
  }
}

// Generates naive scalar implementation of scan for a given operator f.
// For inclusive,
//     output[0] = input[0]
//     output[i] = f(output[i-1], input[i])
//
// For exclusive,
//     output[0] = 0
//     output[i] = f(output[i-1], input[i-1])
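// Illustrative example (assumed data, not part of the op definition): with
// f = addition, accumulator = 0, and input = [1, 2, 3, 4]:
//   inclusive scan -> output = [1, 3, 6, 10]
//   exclusive scan -> output = [0, 1, 3, 6]
// In the exclusive case the code below initializes output[0] from the
// accumulator operand rather than from a literal zero.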
LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                   ValueRange ivs) {
  SmallVector<Value> indices, scanBlkArgs;
  indices.append(ivs.begin(), ivs.end());
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  uint64_t scanDim = getDimension();
  Value cond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                       indices[scanDim], zero);
  bool isInclusive = getInclusive();
  SmallVector<Value> accIndices;
  for (size_t i = 0; i < indices.size(); i++) {
    if (i != scanDim)
      accIndices.push_back(indices[i]);
  }

  auto scfIf = b.create<scf::IfOp>(
      loc, cond,
      [&](OpBuilder &b, Location loc) {
        if (isInclusive) {
          auto value = b.create<memref::LoadOp>(loc, input(), indices);
          b.create<memref::StoreOp>(loc, value, output(), indices);
        } else {
          auto value = b.create<memref::LoadOp>(loc, accumulator(), accIndices);
          b.create<memref::StoreOp>(loc, value, output(), indices);
        }
        b.create<scf::YieldOp>(loc);
      },
      [&](OpBuilder &b, Location loc) {
        SmallVector<Value> indices(ivs.begin(), ivs.end());
        Value iv = indices[scanDim];
        Value ivMinusOne = b.create<arith::SubIOp>(loc, iv, one);
        indices[scanDim] = ivMinusOne;
        scanBlkArgs.push_back(b.create<memref::LoadOp>(loc, output(), indices));
        Value i0;
        if (!isInclusive)
          i0 = b.create<memref::LoadOp>(loc, input(), indices);
        indices[scanDim] = iv;
        if (isInclusive)
          i0 = b.create<memref::LoadOp>(loc, input(), indices);
        scanBlkArgs.push_back(i0);
      });

  auto &srcBlock = getRegion().front();
  Region &thisRegion = scfIf.getElseRegion();
  IRMapping bvm;
  {
    OpBuilder::InsertionGuard guard(b);
    auto &block = thisRegion.front();
    b.setInsertionPointToEnd(&block);
    for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) {
      bvm.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto &blockOp : srcBlock.without_terminator()) {
      b.clone(blockOp, bvm);
    }
    b.create<memref::StoreOp>(
        loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
        output(), indices);
    b.create<memref::StoreOp>(
        loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
        accumulator(), accIndices);
    b.create<scf::YieldOp>(loc);
  }
  return success();
}

static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

LogicalResult ScanOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

static Type getComplexElementTypeOrSelf(Type ty) {
  if (auto complex = dyn_cast_or_null<ComplexType>(ty))
    return complex.getElementType();
  return ty;
}

static bool isInvalid(ArrayRef<int64_t> dimsPos, int64_t rank) {
  // early exit.
  if (static_cast<int64_t>(dimsPos.size()) > rank)
    return true;
  DenseSet<int64_t> uniqued;
  for (int64_t dim : dimsPos)
    uniqued.insert(dim);
  if (static_cast<size_t>(dimsPos.size()) != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= rank;
  });
}
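// Illustrative example of the shape contract verified below (assumed shapes,
// not taken from the op definition): with
//   original : memref<8x4xf32>, indices : memref<?x1xi32>,
//   updates  : memref<?x4xf32>, dimension_map = [0],
// the index depth is 1, dim 0 of `updates`/`indices` counts the updates, and
// each slice updates[n, :] is combined into original[indices[n, 0], :] via
// the op's region.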
LogicalResult ScatterOp::verify() {
  Operation *op = getOperation();
  if (getInputs().size() != 2) {
    return op->emitOpError("expected two input operands");
  }
  if (getOutputs().size() != 1) {
    return op->emitOpError("expected one output operand");
  }
  auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
    return t1.getShape()[dim] == t2.getShape()[dim];
  };

  auto indicesType = getIndicesType();
  if (indicesType.getRank() != 2 ||
      !indicesType.getElementType().isInteger(32)) {
    return emitOpError("expected indices to be of rank 2 of i32 element type");
  }
  auto indexDepth = getIndexDepth();
  if (ShapedType::isDynamic(indexDepth)) {
    return emitOpError("expected index depth is static");
  }

  ArrayRef<int64_t> dimMap = getDimensionMap();
  if (static_cast<int64_t>(dimMap.size()) != indexDepth) {
    return op->emitOpError("invalid number of dimension map entries ");
  }

  auto originalType = getOriginalType();
  if (isInvalid(dimMap, originalType.getRank()))
    return op->emitOpError("dimension map is invalid");

  // The first dimension of the indices should match the first dimension of the
  // output. They indicate the number of updates.
  auto updateType = getUpdateType();
  if (updateType.getRank() < 1) {
    return emitOpError("expected update value to be at least rank 1");
  }
  if (!checkDimensionsMatch(indicesType, updateType, 0)) {
    return emitOpError(
        "mismatch in shape of indices and update value at dim#0");
  }
  if (updateType.getRank() - 1 > originalType.getRank()) {
    return emitOpError(
        "update value rank exceeds the rank of the original value");
  }

  // indexDepth + update dims should cover the original dims. The first dim of
  // update is the number of updates.
  if (originalType.getRank() > indexDepth + updateType.getRank() - 1) {
    return emitOpError(
        "index depth and update value does not cover rank of original value");
  }

  // Validate the non-indexed update dims cover the full slice size of the
  // original tensor.
  int64_t fullSliceDims = originalType.getRank() - indexDepth;
  for (auto it :
       llvm::zip(llvm::seq<int64_t>(indexDepth, originalType.getRank()),
                 llvm::seq<int64_t>(updateType.getRank() - fullSliceDims,
                                    updateType.getRank()))) {
    int64_t originalDim = std::get<0>(it);
    int64_t updateDim = std::get<1>(it);
    if (!originalType.isDynamicDim(originalDim) &&
        updateType.getDimSize(updateDim) >
            originalType.getDimSize(originalDim)) {
      return op->emitOpError("shape of update value dim#")
             << updateDim << " exceeds original value at dim#" << originalDim;
    }
  }

  // Check that the remaining update indices do not exceed the update length.
  int64_t insertDims = originalType.getRank() - updateType.getRank() + 1;
  for (auto it : llvm::zip(
           llvm::seq<int64_t>(insertDims, indexDepth),
           llvm::seq<int64_t>(1, updateType.getRank() - fullSliceDims))) {
    int64_t originalDim = std::get<0>(it);
    int64_t updateDim = std::get<1>(it);
    if (!originalType.isDynamicDim(originalDim) &&
        updateType.getDimSize(updateDim) >
            originalType.getDimSize(originalDim)) {
      return op->emitOpError("indexed shape of update value dim#")
             << updateDim << " exceeds original value at dim#" << originalDim
             << " " << updateType.getDimSize(updateDim) << " "
             << originalType.getDimSize(originalDim);
    }
  }

  Region &region = this->getRegion();
  Block *body = &region.front();
  if (body->getNumArguments() != 2) {
    return op->emitOpError("expected region to have two arguments");
  }
  Type arg0Type = body->getArgument(0).getType();
  Type arg1Type = body->getArgument(1).getType();
  if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() ||
      !getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) {
    return emitOpError(
        "expected region to have scalar argument of integer or float types");
  }
  if (arg0Type != updateType.getElementType()) {
    return emitOpError("mismatch in argument 0 of region ")
           << arg0Type << " and element type of update value "
           << updateType.getElementType();
  }
  if (arg1Type != originalType.getElementType()) {
    return emitOpError("mismatch in argument 1 of region ")
           << arg1Type << " and element type of original value "
           << originalType.getElementType();
  }
  if (arg0Type != arg1Type) {
    return emitOpError("mismatch in region argument types ")
           << arg0Type << " and " << arg1Type;
  }
  auto yieldOp = cast<TMTensor::YieldOp>(body->getTerminator());
  if (yieldOp->getNumOperands() != 1) {
    return yieldOp.emitOpError("expected region to yield a single value");
  }
  auto yieldedType = yieldOp->getOperand(0).getType();
  if (yieldedType != arg0Type) {
    return yieldOp.emitOpError("mismatch in type of yielded value ")
           << yieldedType << " and argument of the region " << arg0Type;
  }
  return success();
}

SmallVector<utils::IteratorType> ScatterOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getUpdateType().getRank(),
                                                 utils::IteratorType::parallel);
  if (!getUniqueIndices()) {
    iteratorTypes[0] = utils::IteratorType::reduction;
  }
  return iteratorTypes;
}

bool ScatterOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  unsigned bbArgNumber;
  Value operand = opOperand->get();
  if (operand == updates())
    bbArgNumber = 0; // block arg 0 is `update`.
  else {
    bool isValidOperand = operand == indices() || operand == original();
    (void)isValidOperand;
    assert(isValidOperand &&
           "operand must belong to the current tm_tensor.scatter op");
    return true;
  }

  assert(this->getOperation()->getNumRegions() == 1 &&
         "unexpected "
         "missing region (calling `payloadUsesValueFromOperand` on "
         "manually defined named TMTensor op?)");
  Block &block = this->getOperation()->getRegion(0).front();
  return !block.getArgument(bbArgNumber).use_empty();
}

SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  SmallVector<Range> ranges;
  for (auto dim : llvm::seq<int64_t>(0, getUpdateType().getRank())) {
    Value ub = getDimValue(builder, loc, updates(), dim);
    ranges.emplace_back(Range{zero, ub, one});
  }
  return ranges;
}

LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
                                                      Location loc,
                                                      ValueRange ivs) {
  auto indexDepth = getIndexDepth();
  Value update = b.create<memref::LoadOp>(loc, updates(), ivs);
  SmallVector<Value> starts;
  SmallVector<Value> loadIndices;
  loadIndices.push_back(ivs.front());
  loadIndices.push_back(Value());

  // Populate with empty values.
  auto originalTy = cast<ShapedType>(original().getType());
  starts.resize(originalTy.getRank(), Value());
  auto updateIvs = ivs.drop_front(1);

  int64_t offset = starts.size() - updateIvs.size();
  for (auto it : llvm::enumerate(updateIvs)) {
    starts[it.index() + offset] = it.value();
  }
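  // Worked example (assumed shapes, for illustration only): with original
  // rank 2, indexDepth = 1, dimension_map = [0], and ivs = (n, j), the loop
  // above leaves starts = [null, j]; the loop below then loads indices[n, 0],
  // casts it to index, and fills starts = [indices[n, 0], j], the point of
  // `original` that this update element targets.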
  ArrayRef<int64_t> dimMap = getDimensionMap();
  for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
    loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
    Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
    Value ret = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);

    auto dim = dimMap[i];

    if (starts[dim])
      ret = b.create<arith::AddIOp>(loc, ret, starts[dim]);
    starts[dim] = ret;
  }

  Value init = b.create<memref::LoadOp>(loc, original(), starts);

  IRMapping bvm;
  Block &block = getRegion().front();
  bvm.map(block.getArgument(0), update);
  bvm.map(block.getArgument(1), init);
  for (auto &blockOp : block.without_terminator()) {
    b.clone(blockOp, bvm);
  }
  // The last op is linalg_ext.yield op. Store the operand to
  // destination.
  b.create<memref::StoreOp>(
      loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)),
      original(), starts);
  return success();
}

//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//

LogicalResult SortOp::verify() {
  Operation *op = getOperation();
  if (getNumInputs()) {
    return op->emitOpError("does not expect to take any inputs");
  }
  if (getNumOutputs() == 0) {
    return op->emitOpError("expected at least one `outs` operand");
  }

  Block &block = getRegion().front();
  size_t numOutputs = getNumOutputs();
  if (block.getNumArguments() != 2 * numOutputs) {
    return op->emitOpError("region block should have ")
           << 2 * numOutputs << " arguments";
  }

  int64_t rank = getOperandRank();
  int sortDim = getDimension();
  if (sortDim < 0 || sortDim >= rank) {
    return op->emitOpError("dimension must be within (0, ") << rank << "]";
  }

  ArrayRef<int64_t> shape = getOperandShape();
  for (auto indexedOperand : llvm::enumerate(getOutputs())) {
    int index = indexedOperand.index();
    auto operandType = getOperandType(index);
    if (operandType.getRank() != rank) {
      return op->emitOpError("expected operand ")
             << index << " to be rank " << rank << ", same as other operands";
    }
    if (operandType.getShape() != shape) {
      return op->emitOpError("expected operand ")
             << index << " to have same shape as other operands";
    }

    Type elemType = operandType.getElementType();
    for (int i : {2 * index, 2 * index + 1}) {
      Type argType = block.getArgument(i).getType();
      if (argType != elemType) {
        return op->emitOpError("region block argument #")
               << i << " should be of type " << elemType << " but got "
               << argType;
      }
    }
  }

  auto yieldOp = cast<YieldOp>(block.getTerminator());
  if (yieldOp.getNumOperands() != 1) {
    return op->emitOpError("should yield exactly one operand");
  }
  auto ty = dyn_cast<IntegerType>(yieldOp.getOperand(0).getType());
  if (!ty || ty.getWidth() != 1) {
    return op->emitOpError("should yield i1 type");
  }

  return success();
}

SmallVector<utils::IteratorType> SortOp::getLoopIteratorTypes() {
  // All loops except the dimension to sort along are parallel.
  SmallVector<utils::IteratorType> iteratorTypes(getOperandRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}

SmallVector<Range> SortOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getOperandRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = operand(0);
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}

LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                   ValueRange ivs) {
  auto sortDim = getDimension();
  SmallVector<Value> indices, sortBlkArgs;
  indices.append(ivs.begin(), ivs.end());
  // Bubble sort innermost loop.
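  // One invocation of this function emits a single bubble-sort sweep along
  // `sortDim`: an scf.for over [0, ub-1) loads the pair (iv, iv+1) from each
  // output, evaluates the user region as the comparator, and swaps the pair
  // when the comparator yields false. Repeated sweeps (one per point of the
  // reduction dimension in the iteration domain) yield a fully sorted result.
  // Illustrative example (assumed data): sorting [2, 3, 1] ascending, the
  // first sweep produces [2, 1, 3] and the second sweep produces [1, 2, 3].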
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value ub;
  if (getOperandType(0).isDynamicDim(sortDim)) {
    ub = b.create<memref::DimOp>(loc, operand(0), sortDim);
  } else {
    ub = b.create<arith::ConstantIndexOp>(
        loc, getOperandType(0).getDimSize(sortDim));
  }
  ub = b.create<arith::SubIOp>(loc, ub, one);
  auto scfFor = b.create<scf::ForOp>(
      loc, zero, ub, one, ValueRange{},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iters) {
        SmallVector<Value> indices(ivs);
        Value ivPlusOne = b.create<arith::AddIOp>(loc, iv, one);
        for (auto output : getOutputOperands()) {
          indices[sortDim] = iv;
          sortBlkArgs.push_back(
              b.create<memref::LoadOp>(loc, output->get(), indices));
          indices[sortDim] = ivPlusOne;
          sortBlkArgs.push_back(
              b.create<memref::LoadOp>(loc, output->get(), indices));
        }
      });

  auto &srcBlock = getRegion().front();
  Region &region = scfFor.getRegion();
  IRMapping bvm;
  {
    OpBuilder::InsertionGuard guard(b);
    auto &block = region.front();
    b.setInsertionPointToEnd(&block);
    for (auto it : llvm::zip(srcBlock.getArguments(), sortBlkArgs)) {
      bvm.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto &blockOp : srcBlock.without_terminator()) {
      b.clone(blockOp, bvm);
    }
  }
  Value cond = bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0));

  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToEnd(&region.front());
  b.create<scf::IfOp>(
      loc, cond,
      [&](OpBuilder &b, Location loc) {
        // Do not swap the pairs if true.
        b.create<scf::YieldOp>(loc);
      },
      [&](OpBuilder &b, Location loc) {
        // Swap the pairs if false.
        SmallVector<Value> indices(ivs.begin(), ivs.end());
        Value ivPlusOne =
            b.create<arith::AddIOp>(loc, scfFor.getInductionVar(), one);
        for (int i = 0, e = getNumOutputs(); i < e; ++i) {
          Value v1 = sortBlkArgs[i * 2];
          Value v2 = sortBlkArgs[i * 2 + 1];
          indices[sortDim] = scfFor.getInductionVar();
          b.create<memref::StoreOp>(loc, v2, getOutputOperand(i)->get(),
                                    indices);
          indices[sortDim] = ivPlusOne;
          b.create<memref::StoreOp>(loc, v1, getOutputOperand(i)->get(),
                                    indices);
        }
        b.create<scf::YieldOp>(loc);
      });
  b.create<scf::YieldOp>(loc);
  return success();
}

bool SortOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  // All operands of SortOp will be sorted. So, we'll end up loading/storing
  // from them - hence setting this utility to always return `true`.
  return true;
}
//===----------------------------------------------------------------------===//
// TopkOp
//===----------------------------------------------------------------------===//

LogicalResult TopkOp::verify() {
  Operation *op = getOperation();
  if (getNumInputs() != 1 && getNumInputs() != 2) {
    return op->emitOpError("expected one or two input operands");
  }
  if (getNumOutputs() != 2) {
    return op->emitOpError("expected two output operands");
  }
  // First check added to eliminate comparison of different int types.
  if (getInputRank() < 0 ||
      (getDimension() >= static_cast<uint64_t>(getInputRank()))) {
    return op->emitOpError("dimension exceeds rank");
  }
  // Ensure input/output element types match.
  auto inputValuesType = cast<ShapedType>(values().getType());
  auto outputValuesType = cast<ShapedType>(outputValues().getType());
  if (inputValuesType.getElementType() != outputValuesType.getElementType()) {
    return op->emitOpError(
        "expected input/output value types to be identical");
  }
  // Indices must be int if provided.
  auto outputIndicesType = cast<ShapedType>(outputIndices().getType());
  if (auto inputIndices = indices()) {
    auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
    if (!inputIndicesType.getElementType().isInteger(32) ||
        !outputIndicesType.getElementType().isInteger(32)) {
      return op->emitOpError(
          "expected input/output indices types to be int32");
    }
  }

  // Ranks must match.
  if (inputValuesType.getRank() != outputValuesType.getRank()) {
    return op->emitOpError("expected input/output to have the same rank");
  }
  if (auto inputIndices = indices()) {
    auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
    if (inputIndicesType.getRank() != outputIndicesType.getRank()) {
      return op->emitOpError("expected input/output to have the same rank");
    }
  }
  // Input indices and values must have the same shape.
  if (auto inputIndices = indices()) {
    auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
    if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType)))
      return op->emitOpError("input indices/values shape must match");
  }
  // Output indices and values must have the same shape.
  if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType)))
    return op->emitOpError("output indices/values shape must match");
  // Input shape must match the output shape except for the dimension()
  uint64_t dim = getDimension();
  if (!llvm::all_of(llvm::enumerate(llvm::zip(inputValuesType.getShape(),
                                              outputValuesType.getShape())),
                    [dim](auto e) {
                      if (e.index() == dim) {
                        return true;
                      }
                      std::tuple<int64_t, int64_t> s = e.value();
                      return succeeded(verifyCompatibleShape(std::get<0>(s),
                                                             std::get<1>(s)));
                    })) {
    return op->emitOpError("incompatible input/output shapes");
  }
  // Check region compatibility
  Block &block = getRegion().front();
  if (block.getNumArguments() != 2) {
    return op->emitOpError("region block should have 2 arguments");
  }
  if (block.getArgument(0).getType() != inputValuesType.getElementType() ||
      block.getArgument(1).getType() != inputValuesType.getElementType()) {
    return op->emitOpError("region block types must match input");
  }
  auto terminatorOp = llvm::dyn_cast<YieldOp>(block.getTerminator());
  if (!terminatorOp || !terminatorOp.getOperand(0).getType().isInteger(1)) {
    return op->emitOpError("region block must end with a linalg_ext.yield i1!");
  }
  return success();
}

SmallVector<utils::IteratorType> TopkOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getInputRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}

SmallVector<Range> TopkOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getInputRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = values();
  for (auto dim : llvm::enumerate(getInputType().getShape())) {
    loopBounds[dim.index()].offset = zero;
    loopBounds[dim.index()].size =
        getDimValue(builder, loc, source, dim.index());
    loopBounds[dim.index()].stride = one;
  }
  return loopBounds;
}

LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                   ValueRange ivs) {
  uint64_t kDim = getDimension();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value initialValue = b.create<memref::LoadOp>(loc, values(), ivs);

  // If the indices tensor is not provided, the value index is derived from the
  // loop induction variables.
  Value initialIndex;
  if (indices()) {
    initialIndex = b.create<memref::LoadOp>(loc, *indices(), ivs);
  } else {
    Value rawInitialIndex = ivs[kDim];
    initialIndex =
        b.create<arith::IndexCastOp>(loc, b.getI32Type(), rawInitialIndex);
  }

  // Compute K (ub) from the selected dim of the output.
  Value ub = b.create<memref::DimOp>(loc, outputValues(), getDimension());

  // Inner K loop functions:
  //   Load current K value and index
  //   Compare N/K using inserted block compare
  //   Check if N == K using strict weak ordering, select which index came first
  //   Select new K value from N/K comparison
  //   Select new K index from N/K comparison or which index came first
  //   Store new K value and index
  //   Yield loop carry values after K selection
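  // In effect, each candidate element performs one insertion pass over the K
  // output slots: the loop carry holds the value/index still looking for a
  // slot, and at every slot the comparator decides whether the carry replaces
  // the stored element (the displaced element becomes the new carry). Ties
  // are broken in favor of the smaller index. Illustrative example (assumed
  // data, largest-first comparator): slots [9, 7, 5] and a new candidate 8
  // become [9, 8, 7], with 5 carried out and discarded.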
  Value kValue, kIndex;
  auto scfFor = b.create<scf::ForOp>(
      loc, zero, ub, one, ValueRange{initialValue, initialIndex},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) {
        SmallVector<Value> indices(ivs);
        indices[kDim] = iv;
        kValue = b.create<memref::LoadOp>(loc, outputValues(), indices);
        kIndex = b.create<memref::LoadOp>(loc, outputIndices(), indices);
      });

  SmallVector<Value> indices(ivs);
  indices[kDim] = scfFor.getInductionVar();
  auto loopCarryValues = scfFor.getRegionIterArgs();

  // Retrieve region as black box comparison function f(x,y). Plug into op.
  auto &srcBlock = getRegion().front();
  IRMapping bvmF; // f(x,y)
  IRMapping bvmR; // f(y,x)
  {
    // Save previous insertion point. Continue within loop body.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToEnd(&scfFor.getRegion().front());
    SmallVector<Value> forwardValues{loopCarryValues[0], kValue};
    SmallVector<Value> reverseValues{kValue, loopCarryValues[0]};
    for (auto it : llvm::zip(srcBlock.getArguments(), forwardValues)) {
      bvmF.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto it : llvm::zip(srcBlock.getArguments(), reverseValues)) {
      bvmR.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto &blockOp : srcBlock.without_terminator()) {
      b.clone(blockOp, bvmF);
      b.clone(blockOp, bvmR);
    }
    Value forwardCmpRes = bvmF.lookup(srcBlock.getTerminator()->getOperand(0));
    Value reverseCmpRes = bvmR.lookup(srcBlock.getTerminator()->getOperand(0));

    // Check value equality using the strict weak ordering from the region:
    //   f(x,y) --> forwardCmpRes
    //   f(y,x) --> reverseCmpRes
    // If forwardCmpRes == reverseCmpRes then select which came first.
    Value cmpValuesEqual = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes);
    Value cmpFirstIndex = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex);
    Value combinedCmpEqRes =
        b.create<arith::AndIOp>(loc, cmpValuesEqual, cmpFirstIndex);
    // True if N > K or N came before K.
    Value indexCmpRes =
        b.create<arith::OrIOp>(loc, forwardCmpRes, combinedCmpEqRes);
    // Select results for K based on comparisons.
    Value resultKValue = b.create<arith::SelectOp>(loc, forwardCmpRes,
                                                   loopCarryValues[0], kValue);
    Value resultKIndex =
        b.create<arith::SelectOp>(loc, indexCmpRes, loopCarryValues[1], kIndex);
    b.create<memref::StoreOp>(loc, resultKValue, outputValues(), indices);
    b.create<memref::StoreOp>(loc, resultKIndex, outputIndices(), indices);
    // Select loop carry, opposite of K results.
    Value resultCarryValue = b.create<arith::SelectOp>(
        loc, forwardCmpRes, kValue, loopCarryValues[0]);
    Value resultCarryIndex =
        b.create<arith::SelectOp>(loc, indexCmpRes, kIndex, loopCarryValues[1]);
    b.create<scf::YieldOp>(loc, ValueRange{resultCarryValue, resultCarryIndex});
  }
  return success();
}

bool TopkOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
  // Set to true so that output operands are always initialized.
  return true;
}

#define DEFINE_OP_GET_EFFECTS(OP_NAME)                                         \
  void OP_NAME::getEffects(                                                    \
      SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>      \
          &effects) {                                                          \
    OpOperandVector inputBuffers = getInputBufferOperands();                   \
    OpOperandVector outputBuffers = getOutputBufferOperands();                 \
    getEffectsImpl(effects, getOperation()->getResults(), inputBuffers,        \
                   outputBuffers);                                             \
  }

DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(ScatterOp)
DEFINE_OP_GET_EFFECTS(SortOp)
DEFINE_OP_GET_EFFECTS(TopkOp)

namespace {
/// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any
/// changes.
struct FoldTensorCastOp : public OpInterfaceRewritePattern<TMTensorOp> {
  using OpInterfaceRewritePattern<TMTensorOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(TMTensorOp op,
                                PatternRewriter &rewriter) const override {
    // If no operand comes from a tensor::CastOp and can be folded then fail.
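    // Illustrative example (assumed IR): if an operand is produced by
    //   %0 = tensor.cast %arg : tensor<4x?xf32> to tensor<?x?xf32>
    // the pattern below rebuilds the op directly on %arg (the more static
    // type) and, where a result type changes, casts the new result back to
    // the old type so existing users remain valid.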
    bool hasTensorCastOperand =
        llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
          if (isa<BlockArgument>(opOperand->get()))
            return false;
          auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
          return castOp && canFoldIntoConsumerOp(castOp);
        });
    if (!hasTensorCastOperand)
      return failure();

    SmallVector<Type, 4> newResultTypes;
    newResultTypes.reserve(op->getNumResults());
    SmallVector<Value, 4> newOperands;
    newOperands.reserve(op->getNumOperands());
    // Inputs may fold.
    for (OpOperand *opOperand : op.getInputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
      newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
                                ? tensorCastOp.getSource()
                                : opOperand->get());
    }
    // Init tensors may fold, in which case the resultType must also change.
    for (OpOperand *opOperand : op.getOutputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
      bool fold = canFoldIntoConsumerOp(tensorCastOp);
      newOperands.push_back(fold ? tensorCastOp.getOperand()
                                 : opOperand->get());
      newResultTypes.push_back(newOperands.back().getType());
    }
    // Clone op.
    Operation *newOp =
        op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
      Value oldResult = std::get<0>(result);
      Value newResult = std::get<1>(result);
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// TMTensorDialect
//===----------------------------------------------------------------------===//

void TMTensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastOp>(getContext());
}

#define GET_OP_CLASSES
#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.cpp.inc"