//===- ReduceOpVariants.cpp --------------------------------------*- C++-*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "ReifyAbstractInterpCalculationsUtils.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "llvm/ADT/StringExtras.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; // Create an overwrite in a manner that preserves the // `OverwriteTensorContentsOp` invariant that both arguments // must have the same shape and dtype. static void createOverwriteTensorContents(PatternRewriter &rewriter, Location loc, Value overwriterTensor, Value overwrittenTensor) { Type overwriterTensorType = overwriterTensor.getType(); Type overwrittenTensorType = dyn_cast(overwrittenTensor.getType()) .getWithValueSemantics(); if (overwriterTensorType != overwrittenTensorType) { overwriterTensor = rewriter.create( loc, overwrittenTensorType, overwriterTensor); } rewriter.create(loc, overwriterTensor, overwrittenTensor); } static Type getContainerOrTensorTypeWithValueSemantics(Type type) { if (auto optionalType = dyn_cast(type)) { Type newContainedType = getContainerOrTensorTypeWithValueSemantics( optionalType.getContainedType()); return OptionalType::get(newContainedType); } else if (auto listType = dyn_cast(type)) { Type newContainedType = getContainerOrTensorTypeWithValueSemantics(listType.getContainedType()); return ListType::get(newContainedType); } else if (auto tensorType = dyn_cast(type)) { return tensorType.getWithValueSemantics(); } else { return nullptr; } } static bool operatorOpHasValueSemantics(OperatorOp opOp, std::optional extraLibrary) { if (!extraLibrary.has_value()) return false; auto opName = cast(opOp->getAttr("name")).getValue(); std::string libFuncName = (mlir::torch::Torch::getLibraryFunctionPrefix( LibraryFunctionKind::HasValueSemantics) + Twine(opName)) .str(); auto libFunc = extraLibrary->lookup(libFuncName); return bool(libFunc); } namespace { // Convert value semantic ops operating on mutable arrays to instead operate on // immutable tensors. class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { public: ConvertHasValueSemanticsOpsToValueTensors( MLIRContext *context, const std::optional &extraLibrary) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) { this->extraLibrary = extraLibrary; } LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (isa(op)) { if (!operatorOpHasValueSemantics(cast(op), extraLibrary)) { return rewriter.notifyMatchFailure(op, "does not have value semantics"); } } else if (!op->hasTrait()) { return rewriter.notifyMatchFailure(op, "does not have value semantics"); } rewriter.startOpModification(op); // Convert all operands. SmallVector newOperands; for (OpOperand &opOperand : op->getOpOperands()) { Type operandType = opOperand.get().getType(); if (isa(operandType)) { opOperand.set(rewriter.create(op->getLoc(), opOperand.get())); } else if (auto listType = dyn_cast(operandType)) { if (!(isa(listType.getContainedType()) || isa(listType.getContainedType()))) continue; // Construct a new list whose elements are value tensors copied from // the non-value tensors of the original list. auto listConstruct = opOperand.get().getDefiningOp(); if (!listConstruct) { rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: list of non vtensor type not constructed " "from list construct"); } if (listConstruct.getElements().empty()) continue; // TODO: Handle optional type in list type. if (auto optionalType = dyn_cast(listType.getContainedType())) { if (!llvm::all_of(listConstruct.getElements(), [](Value val) { return val.getType().isa(); })) { rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: list containing optional type is not " "handled."); } } auto newListElements = llvm::to_vector(llvm::map_range( listConstruct.getElements(), [&](Value tensor) -> Value { if (isa(tensor.getType())) { return rewriter.create(op->getLoc(), tensor); } return tensor; })); Type newListType = getContainerOrTensorTypeWithValueSemantics(listType); if (!newListType) { rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "Unable to convert list type to value semantics."); } opOperand.set(rewriter.create( op->getLoc(), newListType, newListElements)); } else if (auto optionalType = dyn_cast(operandType)) { // TODO: A more general way to handle the optional type is to // introduce a `copy.to_optional_vtensor` op. if (!isa(optionalType.getContainedType())) continue; // Create a new optional value whose input is a value tensor copied // from the non value tensor of the original optional value. auto derefine = opOperand.get().getDefiningOp(); if (!derefine) { rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: optional of non vtensor type not from " "derefine"); } if (!isa(derefine.getOperand().getType())) continue; auto newOperand = rewriter.create( op->getLoc(), derefine.getOperand()); opOperand.set(rewriter.create( op->getLoc(), Torch::OptionalType::get(newOperand.getType()), newOperand)); } } // Convert all results. rewriter.setInsertionPointAfter(op); for (Value result : op->getResults()) { auto tensorType = dyn_cast(result.getType()); if (!tensorType) continue; result.setType(tensorType.getWithValueSemantics()); auto nonValueTensor = rewriter.create(op->getLoc(), result); result.replaceAllUsesExcept(nonValueTensor, nonValueTensor); } rewriter.finalizeOpModification(op); return success(); } private: std::optional extraLibrary; }; } // namespace namespace { class TorchMatchSpecializedBackendOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using HandlerFn = LogicalResult (*)(OperatorOp op, ConversionPatternRewriter &rewriter); LogicalResult matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (namedHandlers.contains(op.getNameAttr())) { return namedHandlers.lookup(op.getNameAttr()).front()(op, rewriter); } return failure(); } static void populateSpecializedConversions(TorchMatchSpecializedBackendOp &matcher); static std::unique_ptr getPopulatedMatcher(MLIRContext *context) { auto matcher = std::make_unique(context); populateSpecializedConversions(*matcher); return matcher; }; void populate(StringRef name, HandlerFn fn) { namedHandlers[StringAttr::get(getContext(), name)].push_back(fn); } void populateLegalizedNames(llvm::DenseSet &set) { for (auto handle : namedHandlers) { set.insert(handle.first); } } private: DenseMap> namedHandlers; }; void TorchMatchSpecializedBackendOp::populateSpecializedConversions( TorchMatchSpecializedBackendOp &matcher) { matcher.populate( "torch.aten._scaled_dot_product_flash_attention_for_cpu", [](Torch::OperatorOp op, ConversionPatternRewriter &rewriter) -> LogicalResult { auto uses = op.getResult(1).getUses(); if (uses.end() == uses.begin()) { auto oldOperands = op->getOperands(); llvm::SmallVector newOperands{ oldOperands[0], oldOperands[1], oldOperands[2], oldOperands[5], oldOperands[3], oldOperands[4], oldOperands[6]}; auto newOp = rewriter.create( op.getLoc(), op->getResultTypes()[0], newOperands, op->getAttrs()); rewriter.replaceAllUsesWith(op.getResult(0), newOp.getResult()); rewriter.eraseOp(op); return success(); } return failure(); }); } bool isSpecializedOperation(Torch::OperatorOp op) { return true; } } // namespace // Reduce Ops without value semantics but the corresponding without trailing // underscore variant doesn't exist. namespace { // int(ceil((end - start) / step)) Value calculateArangeResultNumElements(PatternRewriter &rewriter, Location loc, Value start, Value end, Value step) { Value sub = rewriter.create( loc, Torch::NumberType::get(rewriter.getContext()), end, start); Value div = rewriter.create(loc, sub, step); return rewriter.create(loc, div); } class ReduceNonValueSemanticOps : public RewritePattern { public: ReduceNonValueSemanticOps(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { Location loc = op->getLoc(); MLIRContext *ctx = op->getContext(); if (isa(op)) { Operation *newOp = rewriter.create( loc, op->getResultTypes(), op->getOperands()); auto tensor = rewriter.create(loc, newOp->getResult(0)); createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0)); rewriter.replaceOp(op, op->getOperand(0)); return success(); } else if (auto arangeOutOp = dyn_cast(op)) { Value start = arangeOutOp.getStart(); Value end = arangeOutOp.getEnd(); Value step = arangeOutOp.getStep(); Value out = arangeOutOp.getOut(); // `overwrite.tensor.contents` cannot change the tensor shape, // so `out` tensor should have same num_elements with result tensor. // It means that we don't support code like: // `x = torch.randn(12)` // `y = torch.arange(13, out=x)` Value resultNumElements = calculateArangeResultNumElements(rewriter, loc, start, end, step); Value outNumElements = rewriter.create(loc, out); Value eqOrNot = rewriter.create(loc, resultNumElements, outNumElements); rewriter.create( loc, eqOrNot, rewriter.getStringAttr("`out` tensor should have the same " "num_elements with result tenosr")); auto dtype = rewriter.create(loc, out); auto device = rewriter.create(loc, out); auto shape = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(ctx)), out); auto none = rewriter.create(loc); Value newArange = rewriter.create( loc, arangeOutOp.getResult().getType(), start, end, step, dtype, /*layout=*/none, device, /*pin_memory=*/none); Value reshape = rewriter.create( loc, arangeOutOp.getResult().getType(), newArange, shape); auto vtensor = rewriter.create(loc, reshape); createOverwriteTensorContents(rewriter, loc, vtensor, out); rewriter.replaceOp(arangeOutOp, out); return success(); } else { return failure(); } } }; } // namespace namespace { // Reduce the "trailing underscore inplace variant" to the value semantic // variant + an overwrite of the original "self" argument. class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern { public: ReduceTrailingUnderscoreInplaceVariant(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (!op->hasTrait()) return rewriter.notifyMatchFailure(op, "is not trailing_ variant"); SmallVector fragments; llvm::SplitString(op->getName().getStringRef(), fragments, "."); assert(fragments.size() >= 3 && fragments[2].ends_with("_") && "IsTrailingUnderscoreInplaceVariant incorrectly applied"); fragments[2] = fragments[2].drop_back(); std::string noUnderscoreName = llvm::join(fragments, "."); OperationState state(op->getLoc(), noUnderscoreName); state.addTypes(op->getResultTypes()); state.addOperands(op->getOperands()); state.addAttributes(op->getAttrDictionary().getValue()); // Note: No successors or regions. Torch JIT operators don't have any. assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 && "Torch JIT operators shouldn't have regions or successors"); Operation *newOp = rewriter.create(state); // Note: need to convert result to first input's dtype because mix precision // compute would result in different behaviors. // For example: // a = torch.randn(3, 3).half() # float16 // b = torch.randn(3, 3) # float32 // a += b # i.e. torch.ops.aten.add_(a, b), result is float16 // c = a + b # i.e. torch.ops.aten.add(a, b), result is float32 Value none = rewriter.create(op->getLoc()); Value cstFalse = rewriter.create(op->getLoc(), false); auto aDtype = rewriter.create(op->getLoc(), op->getOperand(0)); auto toDtype = rewriter.create( op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0), aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); auto tensor = rewriter.create(op->getLoc(), toDtype); createOverwriteTensorContents(rewriter, op->getLoc(), tensor, op->getOperand(0)); rewriter.replaceOp(op, op->getOperand(0)); return success(); } }; } // namespace static LogicalResult reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op, PatternRewriter &rewriter) { Value valueTensor = rewriter.create(op->getLoc(), op.getValue()); Value tensor = copyTensorToType(rewriter, op->getLoc(), op.getType(), valueTensor); rewriter.replaceOp(op, {tensor}); return success(); } namespace { struct ReduceOpVariantsPass : public ReduceOpVariantsBase { ReduceOpVariantsPass() = default; ReduceOpVariantsPass(StringRef extraLibrary) { this->extraLibrary = extraLibrary.str(); } void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); OwningOpRef extraLibraryModule = ModuleOp::create(UnknownLoc::get(context)); std::optional extraLibraryModuleSymTable = std::nullopt; if (!extraLibrary.empty()) { if (failed(loadExtraLibrary(extraLibrary, extraLibraryModule))) { emitError(getOperation()->getLoc(), "Failed to load extra-library file at " + extraLibrary); return signalPassFailure(); } extraLibraryModuleSymTable = SymbolTable(extraLibraryModule->getOperation()); } patterns.add( context, extraLibraryModuleSymTable); patterns.add(context); patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp); patterns.add(context); // Create specialized matcher: auto specialized = TorchMatchSpecializedBackendOp::getPopulatedMatcher(context); DenseSet specializedNames; specialized->populateLegalizedNames(specializedNames); patterns.insert(std::move(specialized)); ConversionTarget target(*context); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable, &specializedNames](Operation *op) { if (isa(op)) { if (specializedNames.contains(cast(op).getNameAttr())) { return false; } } if (op->hasTrait() || (isa(op) && operatorOpHasValueSemantics(cast(op), extraLibraryModuleSymTable))) { auto hasValueSemantics = [](Type t) { // TODO: Make this an allowlist based on a closed torch dialect // type system. if (auto tensorType = dyn_cast(t)) { return false; } return true; }; return llvm::all_of(op->getOperandTypes(), hasValueSemantics) && llvm::all_of(op->getResultTypes(), hasValueSemantics); } if (op->hasTrait()) { return false; } if (isa(op) && isSpecializedOperation(cast(op))) return false; return true; }); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) { return std::make_unique(extraLibrary); }