//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Casting.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder, Location loc, Value value, Type desiredType, bool userAllowsRefinement) { Type type = value.getType(); // If the value is already of the desired type, we're done. if (type == desiredType) return value; // If the type is a tensor, then adjust the static information. if ((type.isa() && desiredType.isa()) || (type.isa() && desiredType.isa())) { Value adjusted = builder.create(value.getLoc(), desiredType, value); return adjusted; } // If the type is a subtype of desiredType, then we need to derefine it to // desiredType, unless the user allows refinement. 
if (isValidSubtype(type, desiredType)) { if (!userAllowsRefinement) { Value adjusted = builder.create(value.getLoc(), desiredType, value); return adjusted; } else { return value; } } // If the desiredType is subtype of type, then we assume that the desiredType // is dynamically valid, so we do an unchecked cast. if (isValidSubtype(desiredType, type)) { Value adjusted = builder.create(value.getLoc(), desiredType, value); return adjusted; } // No known adjustment. return Value(); } Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType, Value tensor) { auto originalType = tensor.getType().cast(); // Adjust the static information in the type to match between the original and // new types. if (!originalType.hasSameSizesAndDtype(newType)) { tensor = builder.create( loc, originalType.getWithSizesAndDtypeFrom(newType), tensor); } // Unless both the original and new types are both value tensors, we end // up creating one op that converts between the value and non-value tensor // domains. If both the original and new types are both non-value tensors, // then we do the copy by going to a value tensor and back. if (tensor.getType().isa()) tensor = builder.create(loc, tensor); if (newType.isa()) tensor = builder.create(loc, tensor); return tensor; } bool mlir::torch::Torch::isListPotentiallyMutated(Value list) { assert(list.getType().isa()); return llvm::any_of(list.getUsers(), potentiallyMutatesListOperands); } bool mlir::torch::Torch::potentiallyMutatesListOperands(Operation *op) { // TODO: Find a better place to put this assertion. assert((!op->hasTrait() || op->hasTrait()) && "HasValueSemantics should imply ReadOnly!"); // ReadOnly ops trivially do not mutate any list operands. if (op->hasTrait()) return false; // Ops with no MemoryEffectOpInterface effects also do not mutate any list // operands. 
if (auto effects = dyn_cast(op)) { if (effects.hasNoEffect()) return false; } // Conservatively assume that an op might mutate any list operands. return true; } static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) { return IntegerAttr::get(IntegerType::get(context, 64), value); } static FloatAttr getF64FloatAttr(MLIRContext *context, double value) { return FloatAttr::get(Float64Type::get(context), value); } static Value getScalarIntValue(Value input, Location loc, PatternRewriter &rewriter) { auto inputType = input.getType(); if (inputType.isa()) { return input; } auto inputTensorType = inputType.dyn_cast(); if (!inputTensorType) return nullptr; Type inputDtype = inputTensorType.getOptionalDtype(); if (!inputDtype || !inputDtype.isInteger(64)) return nullptr; std::optional inputRank = getTensorRank(input); if (!inputRank || *inputRank != 0) return nullptr; if (auto valueTensorLiteralOp = input.getDefiningOp()) { auto val = valueTensorLiteralOp.getValue() .cast() .getSplatValue(); return rewriter.create( loc, rewriter.getI64IntegerAttr(val)); } else if (auto primNumToTensorScalarOp = input.getDefiningOp()) { return primNumToTensorScalarOp.getA(); } else if (auto tensorIntOp = input.getDefiningOp()) { return tensorIntOp.getT(); } return nullptr; } //===----------------------------------------------------------------------===// // MethodOp //===----------------------------------------------------------------------===// LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto func = symbolTable.lookupNearestSymbolFrom(*this, getFunctionAttr()); if (!func) return emitError() << "'@" << getFunction() << "' does not reference a valid function"; if (func.getVisibility() != SymbolTable::Visibility::Private) return emitError() << "'@" << getFunction() << "' must reference a private function"; if (func.isDeclaration()) return emitError() << "'@" << getFunction() << "' must reference a function that is defined (not " "merely 
declared)"; auto expectedReceiverArgType = NnModuleType::get( getContext(), getOperation()->getParentOfType().getName()); if (func.getFunctionType().getNumInputs() == 0 || func.getFunctionType().getInput(0) != expectedReceiverArgType) { return emitError() << "the referenced function '" << getFunction() << "' must have a first argument of type " << expectedReceiverArgType; } return success(); } //===----------------------------------------------------------------------===// // NnModuleOp //===----------------------------------------------------------------------===// LogicalResult NnModuleOp::verify() { for (Operation &child : *getBody()) if (!isa(&child)) return child.emitOpError() << "is not allowed inside 'torch.nn_module'"; return success(); } LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto classType = symbolTable.lookupNearestSymbolFrom( *this, SymbolRefAttr::get(getContext(), getClassName())); if (!classType) return emitError() << "'" << getClassName() << "' does not reference a valid class type"; auto attrs = llvm::to_vector<6>(getBody()->getOps()); auto attrDefs = llvm::to_vector<6>(classType.getBody()->getOps()); if (attrs.size() != attrDefs.size()) return emitError() << "number of 'torch.slot's in a 'torch.nn_module' must " "match number of 'torch.attr's in " "the corresponding 'torch.class_type'"; for (int i = 0, e = attrs.size(); i != e; i++) { SlotOp attr = attrs[i]; AttrOp attrDef = attrDefs[i]; if (!isValidSubtype(attr.getValue().getType(), attrDef.getType()) || attr.getName() != attrDef.getName()) { return attr.emitOpError() .append("is expected to match type and name of '", attrDef.getOperation(), "'") .attachNote(attrDef.getLoc()) .append("see torch.attr at corresponding index ", i, " here"); } } return success(); } //===----------------------------------------------------------------------===// // PrimListConstructOp //===----------------------------------------------------------------------===// LogicalResult 
PrimListConstructOp::verify() { auto resultType = getResult().getType(); auto resultElementType = resultType.dyn_cast().getContainedType(); auto matchResultElementType = [&](Type type) { return isValidSubtype(type, resultElementType); }; if (!llvm::all_of(getOperandTypes(), matchResultElementType)) { return emitError() << "operand types should have the same type as the " "list contained type"; } return success(); } //===----------------------------------------------------------------------===// // PrimDictConstructOp //===----------------------------------------------------------------------===// LogicalResult PrimDictConstructOp::verify() { auto isValidSubTypeOf = [](Type expectedType) { return [=](Type type) { return isValidSubtype(type, expectedType); }; }; if (!llvm::all_of(getKeys().getTypes(), isValidSubTypeOf(getKeyType()))) return emitError() << "keys should be of Dict key type"; if (!llvm::all_of(getValues().getTypes(), isValidSubTypeOf(getValueType()))) return emitError() << "values should be of Dict value type"; return success(); } //===----------------------------------------------------------------------===// // ClassTypeOp //===----------------------------------------------------------------------===// LogicalResult ClassTypeOp::verify() { llvm::StringMap namesToOps; for (Operation &child : getBody()->without_terminator()) { if (!isa(&child)) return child.emitOpError() << "is not allowed inside `torch.class_type`"; StringRef name; if (auto attr = dyn_cast(child)) name = attr.getName(); else name = cast(child).getName(); auto itAndWasInserted = namesToOps.insert({name, &child}); auto it = itAndWasInserted.first; bool wasInserted = itAndWasInserted.second; if (!wasInserted) { auto diag = emitOpError().append("has duplicate attr/method with name '", name, "'"); diag.attachNote(it->second->getLoc()) .append("see first conflicting attr/method here"); diag.attachNote(child.getLoc()) .append("see second conflicting attr/method here"); return failure(); } } 
return success(); } //===----------------------------------------------------------------------===// // PrimLoopOp //===----------------------------------------------------------------------===// OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) { assert(point == getRegion()); return getIterArgsInit(); } void PrimLoopOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { Region ®ion = getRegion(); if (!point.getRegionOrNull()) { regions.emplace_back(®ion, region.getArguments().slice(1)); return; } assert(point == region); regions.emplace_back(®ion, region.getArguments().slice(1)); regions.emplace_back(getResults()); } bool PrimLoopOp::isForLike() { bool b; return matchPattern(getInitialCondition(), m_TorchConstantBool(&b)) && b; } //===----------------------------------------------------------------------===// // PrimLoopConditionOp //===----------------------------------------------------------------------===// MutableOperandRange PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { // Pass all operands except the condition to the successor which is the // parent loop op. return getIterArgsMutable(); } //===----------------------------------------------------------------------===// // PrimIfOp //===----------------------------------------------------------------------===// ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) { // Create the regions. result.regions.reserve(2); Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); auto &builder = parser.getBuilder(); OpAsmParser::UnresolvedOperand cond; Type boolType = builder.getType(); if (parser.parseOperand(cond) || parser.resolveOperand(cond, boolType, result.operands)) return failure(); // Parse results type list. if (parser.parseArrowTypeList(result.types)) return failure(); // Parse the 'then' region. 
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); // Parse the 'else' region. if (parser.parseKeyword("else")) return failure(); if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); // Parse the optional attribute list. if (parser.parseOptionalAttrDict(result.attributes)) return failure(); return success(); } void PrimIfOp::print(OpAsmPrinter &p) { p << " " << getCondition(); p << " -> (" << getResultTypes() << ") "; p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false); p << " else "; p.printRegion(getElseRegion(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict((*this)->getAttrs()); } void PrimIfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. if (point.getRegionOrNull()) { regions.push_back(RegionSuccessor(getResults())); return; } // If the condition is constant, we can give a more precise answer. bool condition; if (matchPattern(getCondition(), m_TorchConstantBool(&condition))) { Region *executedRegion = condition ? &getThenRegion() : &getElseRegion(); regions.push_back(RegionSuccessor(executedRegion)); return; } // If the condition isn't constant, both regions may be executed. regions.push_back(RegionSuccessor(&getThenRegion())); regions.push_back(RegionSuccessor(&getElseRegion())); return; } /// Replaces the given op with the contents of the given single-block region, /// using the operands of the block terminator to replace operation results. 
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs = {}) { assert(llvm::hasSingleElement(region) && "expected single-region block"); Block *block = ®ion.front(); Operation *terminator = block->getTerminator(); ValueRange results = terminator->getOperands(); rewriter.inlineBlockBefore(block, op, blockArgs); rewriter.replaceOp(op, results); rewriter.eraseOp(terminator); } void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // If the condition is constant, delete the dead branch and inline the live // branch. patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) { auto constantBool = op.getCondition().getDefiningOp(); if (!constantBool) return rewriter.notifyMatchFailure(op, "non-constant condition"); replaceOpWithRegion( rewriter, op, constantBool.getValue() ? op.getThenRegion() : op.getElseRegion()); return success(); }); // If the thenRegion and elseRegion yield the same Value's, then use those // directly. patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) { auto trueTerminator = op.getThenRegion().front().getTerminator(); auto falseTerminator = op.getElseRegion().front().getTerminator(); bool madeChange = false; SmallVector resultsToErase; for (auto t : llvm::zip(trueTerminator->getOperands(), falseTerminator->getOperands(), op->getResults())) { auto trueVal = std::get<0>(t); auto falseVal = std::get<1>(t); auto resultToBeReplaced = std::get<2>(t); if (trueVal == falseVal) { madeChange |= !resultToBeReplaced.use_empty(); resultToBeReplaced.replaceAllUsesWith(trueVal); } } // We leave it up to a separate pattern (not yet implemented) to erase the // results that are now dead. That transformation is independently useful, // and also pretty tricky to implement because it changes the number of // results. return success(madeChange); }); // Erase any dead results. 
patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) { llvm::BitVector resultsToErase(op.getNumResults()); for (auto result : llvm::enumerate(op->getResults())) { if (result.value().use_empty()) resultsToErase.set(result.index()); } // If no results have uses and there are no side effects, just erase the op. // Approximate the body having no side effects by checking if it is just a // terminator. // Note: We don't want to make this logic too fancy, because in general, // checking for recursive side effects can result in a quadratic amount of // work (N nested If's each resulting in O(N) work). It should probably be // split into its own pattern if we want to make it fancier. if (resultsToErase.all() && llvm::hasSingleElement(op.getThenRegion().front()) && llvm::hasSingleElement(op.getElseRegion().front())) { rewriter.eraseOp(op); return success(); } // If there are no results to erase, we're done. if (!resultsToErase.any()) return failure(); SmallVector newResultTypes; for (int i = 0, e = op->getNumResults(); i < e; ++i) { if (resultsToErase[i]) continue; newResultTypes.push_back(op->getResult(i).getType()); } auto newIf = rewriter.create(op->getLoc(), newResultTypes, op.getCondition()); rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), newIf.getThenRegion().end()); rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), newIf.getElseRegion().end()); newIf.getThenRegion().front().getTerminator()->eraseOperands(resultsToErase); newIf.getElseRegion().front().getTerminator()->eraseOperands(resultsToErase); SmallVector replacementValues; for (int i = 0, e = op->getNumResults(), nextNewValue = 0; i < e; ++i) { if (resultsToErase[i]) replacementValues.push_back(nullptr); else replacementValues.push_back(newIf->getResult(nextNewValue++)); } rewriter.replaceOp(op, replacementValues); return success(); }); } //===----------------------------------------------------------------------===// // RuntimeAssertOp 
//===----------------------------------------------------------------------===// void RuntimeAssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](RuntimeAssertOp op, PatternRewriter &rewriter) { bool value; if (!matchPattern(op.getCondition(), m_TorchConstantBool(&value))) return failure(); if (value) { rewriter.eraseOp(op); return success(); } // Even if the condition is statically false, the assert might never be // executed. return failure(); }); } //===----------------------------------------------------------------------===// // DerefineOp //===----------------------------------------------------------------------===// bool DerefineOp::areCastCompatible(mlir::TypeRange inputs, mlir::TypeRange outputs) { return isValidSubtype(inputs[0], outputs[0]); } OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) { auto uncheckedCast = getOperand().getDefiningOp(); if (!uncheckedCast) return nullptr; if (uncheckedCast.getOperand().getType() == getType()) return uncheckedCast.getOperand(); return nullptr; } void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) { bool madeChange = false; for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) { if (use.getOwner()->hasTrait()) { use.set(op.getOperand()); madeChange = true; } } return success(madeChange); }); } static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) { Value lhs = op->getOperand(0); Value rhs = op->getOperand(1); // Look through DerefineOp's to get more refined static information. if (auto derefine = lhs.getDefiningOp()) lhs = derefine.getOperand(); if (auto derefine = rhs.getDefiningOp()) rhs = derefine.getOperand(); Type lhsType = lhs.getType(); Type rhsType = rhs.getType(); // If either type is a NoneType, make it be the lhsType. 
if (rhsType.isa()) { std::swap(lhsType, rhsType); std::swap(lhs, rhs); } // For now, check a few specific cases. // If both types are the singleton `!torch.none` type, then we don't even need // to look at the values. if (lhsType.isa() && rhsType.isa()) return IntegerAttr::get(IntegerType::get(op->getContext(), 1), equalIsTrue); // If neither type is a subtype of the other, then the result is false. // TODO: Implement and use subtype infra for this. // For now, check a specific case. // If the rhs is not OptionalType, then we know it cannot be None. if (lhsType.isa() && !rhsType.isa()) { return IntegerAttr::get(IntegerType::get(op->getContext(), 1), !equalIsTrue); } return nullptr; } //===----------------------------------------------------------------------===// // Aten__RangeLengthOp //===----------------------------------------------------------------------===// OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) { auto lo = adaptor.getLo(); auto hi = adaptor.getHi(); auto step = adaptor.getStep(); if (!lo || !hi || !step) return nullptr; auto loInt = lo.dyn_cast_or_null().getValue(); auto hiInt = hi.dyn_cast_or_null().getValue(); auto stepInt = step.dyn_cast_or_null().getValue(); // TODO: Implement folding for negative steps. 
if (stepInt.isNegative()) return nullptr; // From Python language spec: // r[i] = lo + step*i such that i >= 0 and r[i] < hi // So maximize `i` such that lo + step * i < hi // ==> i == ceildiv(hi - lo, step) return IntegerAttr::get(lo.cast().getType(), llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt, APInt::Rounding::UP)); } //===----------------------------------------------------------------------===// // Aten__DeriveIndexOp //===----------------------------------------------------------------------===// OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) { auto index = adaptor.getIndex(); auto start = adaptor.getStart(); auto step = adaptor.getStep(); if (!index || !start || !step) return nullptr; auto indexInt = index.dyn_cast_or_null().getValue(); auto startInt = start.dyn_cast_or_null().getValue(); auto stepInt = step.dyn_cast_or_null().getValue(); return IntegerAttr::get(index.cast().getType(), startInt + stepInt * indexInt); } //===----------------------------------------------------------------------===// // Aten__Is__Op //===----------------------------------------------------------------------===// OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) { return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true); } //===----------------------------------------------------------------------===// // Aten__Isnot__Op //===----------------------------------------------------------------------===// OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) { return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false); } //===----------------------------------------------------------------------===// // Aten__Not__Op //===----------------------------------------------------------------------===// OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) { bool value; if (!matchPattern(getOperand(), m_TorchConstantBool(&value))) return nullptr; return IntegerAttr::get(IntegerType::get(getContext(), 1), !value); } 
//===----------------------------------------------------------------------===// // AtenNeBoolOp //===----------------------------------------------------------------------===// OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) { if (getOperand(0) == getOperand(1)) return IntegerAttr::get(IntegerType::get(getContext(), 1), false); bool a, b; if (!matchPattern(getOperand(0), m_TorchConstantBool(&a))) return nullptr; if (!matchPattern(getOperand(1), m_TorchConstantBool(&b))) return nullptr; return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b); } //===----------------------------------------------------------------------===// // AtenSqueezeOp //===----------------------------------------------------------------------===// OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { if (auto tensorType = getOperand().getType().dyn_cast()) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(); } return nullptr; } //===----------------------------------------------------------------------===// // AtenSqueezeDimOp //===----------------------------------------------------------------------===// OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { if (auto tensorType = getOperand(0).getType().dyn_cast()) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(0); } return nullptr; } //===----------------------------------------------------------------------===// // AtenRoundOp //===----------------------------------------------------------------------===// OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { if (auto selfType = getSelf().getType().dyn_cast()) { if (selfType.hasDtype() && selfType.getDtype().isa()) return getSelf(); } return nullptr; } //===----------------------------------------------------------------------===// // AtenToDtypeOp //===----------------------------------------------------------------------===// OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) { bool 
nonBlocking, copyArg; // The non_blocking arg must be `False`. if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) || nonBlocking) return nullptr; // The copy arg must be `False`. if (!matchPattern(getCopy(), m_TorchConstantBool(©Arg)) || copyArg) return nullptr; // The memory_format arg must be `none`. if (!getMemoryFormat().getType().isa()) return nullptr; auto inputType = getSelf().getType().cast(); auto resType = getType().cast(); // If the types aren't equal, then we can't fold. if (inputType != resType) return nullptr; // If the type does not have a statically known dtype, then we cannot fold. // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong, // since the `unk` could be dynamically different for the operand and result. if (!inputType.hasDtype()) return nullptr; // Fold when both the input tensor and result are of the same type. return getOperand(0); } //===----------------------------------------------------------------------===// // AtenToDtypeLayoutOp //===----------------------------------------------------------------------===// OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) { // The pin_memory arg should be either constant `False` or `none`. if (!getPinMemory().getType().isa()) { bool pinMemory; if (!matchPattern(getPinMemory(), m_TorchConstantBool(&pinMemory))) return nullptr; else if (pinMemory) return nullptr; } // The non_blocking arg should be constant `False`. bool nonBlocking; if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking))) return nullptr; else if (nonBlocking) return nullptr; // The copy arg should be constant `False`. bool copyArg; if (!matchPattern(getCopy(), m_TorchConstantBool(©Arg))) return nullptr; else if (copyArg) return nullptr; // The device arg must be `none`. if (!getDevice().getType().isa()) return nullptr; // The memory_format arg must be `none`. 
if (!getMemoryFormat().getType().isa()) return nullptr; auto inputType = getSelf().getType().cast(); auto resType = getType().cast(); // If the types aren't equal, then we can't fold. if (inputType != resType) return nullptr; // If the type does not have a statically known dtype, then we cannot fold. // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong, // since the `unk` could be dynamically different for the operand and result. if (!inputType.hasDtype()) return nullptr; // The layout arg should be either `none` or `0` i.e. strided. if (!getLayout().getType().isa()) { int64_t tensorLayout; if (!matchPattern(getLayout(), m_TorchConstantInt(&tensorLayout))) return nullptr; else if (tensorLayout != torch_upstream::Layout::Strided) return nullptr; } // Fold when both the input tensor and result are of the same type and the // layout arg is strided. return getOperand(0); } void AtenToDtypeLayoutOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { // `to.dtype_layout` -> `to.device/to.dtype` if layout is none and pin memory // is false patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) { // The pin_memory arg should be either constant `False` or `none`. if (!op.getPinMemory().getType().isa()) { bool pinMemory; if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory))) return failure(); else if (pinMemory) return failure(); } // The layout arg should be either `none` or `0` i.e. strided. if (!op.getLayout().getType().isa()) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return failure(); else if (tensorLayout != torch_upstream::Layout::Strided) return failure(); } if (op.getDevice().getType().isa()) { // The device arg is `none`. Rewrite to to.dtype. 
AtenToDtypeOp toDtype = rewriter.create( op.getLoc(), op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); rewriter.replaceOp(op, toDtype->getResults()); } else { // The device arg is not `none`. Rewrite to to.device. AtenToDeviceOp toDevice = rewriter.create( op.getLoc(), op.getType(), op.getSelf(), op.getDevice(), op.getDtype(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); rewriter.replaceOp(op, toDevice->getResults()); } return success(); }); } //===----------------------------------------------------------------------===// // AtenToOtherOp //===----------------------------------------------------------------------===// void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // Canonicalize `aten.to.other` to `aten.to.device` patterns.add(+[](AtenToOtherOp op, PatternRewriter &rewriter) { auto lhs = op.getSelf(); auto rhs = op.getOther(); auto getRhsDevice = rewriter.create(op.getLoc(), rhs); auto getRhsDtype = rewriter.create(op.getLoc(), rhs); rewriter.replaceOpWithNewOp( op, op.getType(), lhs, getRhsDevice.getResult(), getRhsDtype.getResult(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); return success(); }); } //===----------------------------------------------------------------------===// // AtenViewOp //===----------------------------------------------------------------------===// OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { auto inputType = getOperand(0).getType().dyn_cast(); if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1) return nullptr; auto resType = getType().dyn_cast(); if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1) return nullptr; // Fold when both the input tensor and result are unity rank tensors. 
return getOperand(0); } //===----------------------------------------------------------------------===// // PrimsViewOfOp //===----------------------------------------------------------------------===// OpFoldResult PrimsViewOfOp::fold(FoldAdaptor adaptor) { // Always fold the op with its only input operand. return getOperand(); } //===----------------------------------------------------------------------===// // AtenDimOp //===----------------------------------------------------------------------===// OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) { if (auto tensorType = getOperand().getType().dyn_cast()) { if (tensorType.hasSizes()) return IntegerAttr::get(IntegerType::get(getContext(), 64), tensorType.getSizes().size()); } return nullptr; } //===----------------------------------------------------------------------===// // AtenLenTOp //===----------------------------------------------------------------------===// OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) { // `len([1,1,1])` -> `3`, if it is not mutated. 
if (auto listConstruct = getOperand().getDefiningOp()) { if (!isListPotentiallyMutated(listConstruct)) { return IntegerAttr::get(IntegerType::get(getContext(), 64), listConstruct.getNumOperands()); } } return nullptr; } void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // `len(t.size())` -> `t.ndim` patterns.add(+[](AtenLenTOp op, PatternRewriter &rewriter) { auto size = op.getOperand().getDefiningOp(); if (!size) return rewriter.notifyMatchFailure(op, "operand not AtenSizeOp"); rewriter.replaceOpWithNewOp(op, size.getOperand()); return success(); }); } //===----------------------------------------------------------------------===// // AtenMinOtherOp //===----------------------------------------------------------------------===// void AtenMinOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // `aten.min.other` -> `aten.minimum` patterns.add(+[](AtenMinOtherOp op, PatternRewriter &rewriter) { rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getOther()); return success(); }); } //===----------------------------------------------------------------------===// // AtenMaxOtherOp //===----------------------------------------------------------------------===// void AtenMaxOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // `aten.max.other` -> `aten.maximum` patterns.add(+[](AtenMaxOtherOp op, PatternRewriter &rewriter) { rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getOther()); return success(); }); } //===----------------------------------------------------------------------===// // AtenLenStrOp //===----------------------------------------------------------------------===// OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) { if (auto stringConstruct = getS().getDefiningOp()) return getI64IntegerAttr(getContext(), stringConstruct.getValueAttr().getValue().size()); return nullptr; } LogicalResult 
rewrite0DBinaryTensorOp(Operation *op, PatternRewriter &rewriter) { Location loc = op->getLoc(); // This canonicalization pattern also includes aten div/mul/add/sub ops // between tensor and scalar, like aten.add.Scalar op if (op->getNumOperands() < 2) { return failure(); } auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter); auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter); auto outType = op->getResult(0).getType(); if (!lhs || !rhs) { return rewriter.notifyMatchFailure( op, "only int scalar lhs or rhs is supported"); } if (isa(op)) { Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter); if (!alpha) { return rewriter.notifyMatchFailure(op, "only int scalar alpha is supported"); } if (isa(op)) lhs = rewriter.create(loc, lhs, alpha); else rhs = rewriter.create(loc, rhs, alpha); } if (isa(op)) { // None rounding mode if (op->getOperand(2).getType().isa()) { Value quotient = rewriter.create(loc, lhs, rhs); rewriter.replaceOpWithNewOp(op, outType, quotient); return success(); } std::string roundingMode; if (!matchPattern(op->getOperand(2), m_TorchConstantStr(roundingMode))) { return rewriter.notifyMatchFailure( op, "only None, 'floor' or 'trunc' rounding mode is supported"); } if (roundingMode == "floor") { Value quotient = rewriter.create(loc, lhs, rhs); rewriter.replaceOpWithNewOp(op, outType, quotient); return success(); } // For "trunc" rounding mode, insted of canonicalizing it into // aten.abs, aten.floor, aten.sign and aten.mul.int ops, which adds // complexity but helps little in optimization (such as constant folding), // we are trying to fold it. 
if (roundingMode == "trunc") { int64_t lhsInt; int64_t rhsInt; if (!matchPattern(lhs, m_TorchConstantInt(&lhsInt))) { return failure(); } if (!matchPattern(rhs, m_TorchConstantInt(&rhsInt))) { return failure(); } int64_t result = (int64_t)std::trunc((double)lhsInt / rhsInt); Value resultScalar = rewriter.create( loc, rewriter.getI64IntegerAttr(result)); rewriter.replaceOpWithNewOp(op, outType, resultScalar); return success(); } return failure(); } Value result; // Other Add/Sub/Mul ops if (isa(op)) { result = rewriter.create(loc, lhs, rhs); } else if (isa(op)) { result = rewriter.create(loc, lhs, rhs); } else if (isa(op)) { result = rewriter.create(loc, rhs, lhs); } else if (isa(op)) { result = rewriter.create(loc, lhs, rhs); } rewriter.replaceOpWithNewOp(op, outType, result); return success(); } //===----------------------------------------------------------------------===// // AtenAddTensorOp //===----------------------------------------------------------------------===// void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenAddTensorOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } //===----------------------------------------------------------------------===// // AtenAddScalarOp //===----------------------------------------------------------------------===// void AtenAddScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenAddScalarOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } //===----------------------------------------------------------------------===// // AtenSubTensorOp //===----------------------------------------------------------------------===// void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenSubTensorOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, 
rewriter); }); } //===----------------------------------------------------------------------===// // AtenSubScalarOp //===----------------------------------------------------------------------===// void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenSubScalarOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } //===----------------------------------------------------------------------===// // AtenRSubScalarOp //===----------------------------------------------------------------------===// void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenRsubScalarOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } //===----------------------------------------------------------------------===// // AtenMulTensorOp //===----------------------------------------------------------------------===// void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenMulTensorOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } //===----------------------------------------------------------------------===// // AtenMulScalarOp //===----------------------------------------------------------------------===// void AtenMulScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenMulScalarOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } //===----------------------------------------------------------------------===// // AtenDivTensorModeOp //===----------------------------------------------------------------------===// void AtenDivTensorModeOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenDivTensorModeOp op, PatternRewriter &rewriter) { return 
rewrite0DBinaryTensorOp(op, rewriter); }); } //===----------------------------------------------------------------------===// // Aten__Or__TensorOp //===----------------------------------------------------------------------===// void Aten__Or__TensorOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](Aten__Or__TensorOp op, PatternRewriter &rewriter) { rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getOther()); return success(); }); } //===----------------------------------------------------------------------===// // AtenScalarImplicitOp //===----------------------------------------------------------------------===// void AtenScalarImplicitOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenScalarImplicitOp op, PatternRewriter &rewriter) { Location loc = op.getLoc(); Value a = op.getA(); auto outType = op.getResult().getType(); Value scalarValue = getScalarIntValue(a, loc, rewriter); if (!scalarValue) return failure(); rewriter.replaceOpWithNewOp(op, outType, scalarValue); return success(); }); } //===----------------------------------------------------------------------===// // AtenSizeOp //===----------------------------------------------------------------------===// // Traces at most 6 parents of `value` to determine the tensor type with known // dimension size or returns failure if such a type was not found. If `dim` is // `None`, then all dimension's sizes must be known. static FailureOr traceKnownSizeTensorType(Value value, std::optional dim) { // Function to check if we found a type that contains the queried information. auto foundType = [](BaseTensorType tensorType, std::optional(dim)) { if (!tensorType.hasSizes()) return false; if (dim == std::nullopt) return tensorType.areAllSizesKnown(); // If the dimension value is negative, then convert it to a positive value. 
ArrayRef sizes = tensorType.getSizes(); *dim = toPositiveDim(*dim, sizes.size()); return isValidDim(*dim, sizes.size()) && sizes[*dim] != kUnknownSize; }; // Limit the loop count to 6 to avoid indefinite compilation times from // unbounded IR traversals. for (auto idx = 0; idx < 6; ++idx) { if (!value || !value.getType().isa()) return failure(); auto tensorType = value.getType().cast(); if (foundType(tensorType, dim)) return tensorType; auto op = value.getDefiningOp(); if (!op || !isa(op)) return failure(); // In all ops of interest to us, the source tensor is operand #0. value = op->getOperand(0); } return failure(); } void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) { auto type = traceKnownSizeTensorType(op.getOperand(), std::nullopt); if (failed(type)) return rewriter.notifyMatchFailure(op, "all sizes not known"); SmallVector listElements; for (int64_t size : type->getSizes()) { listElements.push_back(rewriter.create( op->getLoc(), rewriter.getI64IntegerAttr(size))); } rewriter.replaceOpWithNewOp( op, Torch::ListType::get(rewriter.getType()), listElements); return success(); }); // One-off pattern to erase if dead. // TODO: Use the effects infra to express the semantics of this op and enable // a centralized "erase if dead" canonicalization. // Specifically, we need to mark the op as only MemoryEffects::Allocate // so that `mlir::wouldOpBeTriviallyDead` does the right thing. 
patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) { if (!op.use_empty()) return failure(); rewriter.eraseOp(op); return failure(); }); } //===----------------------------------------------------------------------===// // AtenSizeIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) { int64_t dim; if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim))) return nullptr; auto type = traceKnownSizeTensorType(this->getSelf(), dim); if (failed(type)) return nullptr; ArrayRef sizes = type->getSizes(); dim = toPositiveDim(dim, sizes.size()); if (!isValidDim(dim, sizes.size())) return nullptr; return IntegerAttr::get(IntegerType::get(getContext(), 64), sizes[dim]); } //===----------------------------------------------------------------------===// // AtenGtIntOp //===----------------------------------------------------------------------===// static IntegerAttr getI1IntegerAttr(MLIRContext *context, bool value) { return IntegerAttr::get(IntegerType::get(context, 1), static_cast(value)); } using ConstantFloatComparator = std::function; template static OpFoldResult floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) { if (op.getOperand(0) == op.getOperand(1)) return getI1IntegerAttr(op.getContext(), comparator(0, 0)); double lhs, rhs; if (!matchPattern(op.getOperand(0), m_TorchConstantFloat(&lhs)) || !matchPattern(op.getOperand(1), m_TorchConstantFloat(&rhs))) return nullptr; return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs)); } //===----------------------------------------------------------------------===// // AtenLtFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a < b; }); } //===----------------------------------------------------------------------===// // AtenGtFloatOp 
//===----------------------------------------------------------------------===// OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a > b; }); } //===----------------------------------------------------------------------===// // AtenGeFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a >= b; }); } //===----------------------------------------------------------------------===// // AtenEqFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a == b; }); } using ConstantIntComparator = std::function; template static OpFoldResult intComparatorFoldHelper(OpTy op, ConstantIntComparator comparator) { Value lhsValue = op->getOperand(0); Value rhsValue = op->getOperand(1); if (lhsValue == rhsValue) return getI1IntegerAttr(op.getContext(), comparator(0, 0)); int64_t lhs, rhs; bool lhsIsConstant = matchPattern(lhsValue, m_TorchConstantInt(&lhs)); bool rhsIsConstant = matchPattern(rhsValue, m_TorchConstantInt(&rhs)); if (lhsIsConstant && rhsIsConstant) return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs)); // Ensure that if there is a constant, it is on the right. if (lhsIsConstant && !rhsIsConstant) { std::swap(lhs, rhs); std::swap(lhsValue, rhsValue); std::swap(lhsIsConstant, rhsIsConstant); auto newComparator = [comparator](int64_t lhs, int64_t rhs) { return comparator(rhs, lhs); }; comparator = newComparator; } // Fold comparisons of AtenSizeIntOp against negative values. // AtenSizeIntOp is known to always be non-negative. 
if (rhsIsConstant && rhs < 0) { // We can return `comparator(0, -1)` here because of the property: // If x >= 0 && y < 0, then: // - cmp(x, y) == cmp(x + 1, y) // - cmp(x, y) == cmp(x, y - 1) // By induction all cases here are covered. if (auto size = lhsValue.getDefiningOp()) return getI1IntegerAttr(op->getContext(), comparator(0, -1)); } // Fold comparisons of AtenSizeIntOp against 0: // - torch.aten.size.int >= 0 ==> True. // - torch.aten.size.int < 0 ==> False. // (and the operand-swapped versions of the above) if (rhsIsConstant && rhs == 0) { if (auto size = lhsValue.getDefiningOp()) { // >= 0 comparison. if (comparator(0, 0) && comparator(1, 0)) return getI1IntegerAttr(op->getContext(), true); // < 0 comparison. if (!comparator(0, 0) && comparator(-1, 0) && !comparator(1, 0)) return getI1IntegerAttr(op->getContext(), false); } } return nullptr; } //===----------------------------------------------------------------------===// // AtenDetachOp //===----------------------------------------------------------------------===// OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { return getSelf(); } //===----------------------------------------------------------------------===// // AtenNeIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a != b; }); } //===----------------------------------------------------------------------===// // AtenEqIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a == b; }); } //===----------------------------------------------------------------------===// // AtenEqStrOp //===----------------------------------------------------------------------===// OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) { if 
(getOperand(0) == getOperand(1)) return getI1IntegerAttr(getContext(), true); auto aStr = getA().getDefiningOp(); auto bStr = getB().getDefiningOp(); if (aStr && bStr) return getI1IntegerAttr(getContext(), aStr == bStr); return nullptr; } //===----------------------------------------------------------------------===// // AtenLtIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a < b; }); } //===----------------------------------------------------------------------===// // AtenLeIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a <= b; }); } //===----------------------------------------------------------------------===// // AtenGtIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a > b; }); } //===----------------------------------------------------------------------===// // AtenGeIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a >= b; }); } //===----------------------------------------------------------------------===// // AtenBoolFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) { double c; if (matchPattern(getOperand(), m_TorchConstantFloat(&c))) return getI1IntegerAttr(getContext(), c != 0.0); return nullptr; } //===----------------------------------------------------------------------===// // AtenBoolIntOp 
//===----------------------------------------------------------------------===//

// Folds a constant int operand to `c != 0`.
OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) {
  int64_t c;
  if (matchPattern(getOperand(), m_TorchConstantInt(&c)))
    return getI1IntegerAttr(getContext(), c != 0);
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenAnyBoolOp
//===----------------------------------------------------------------------===//

// Folds to true when the input list has a defining op, is not potentially
// mutated, and at least one of its operands is a constant `true`.  A list
// without any constant-true operand is left unfolded.
OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
  auto inputConstruct = getSelf().getDefiningOp();
  if (!inputConstruct || isListPotentiallyMutated(inputConstruct))
    return nullptr;
  // If any operand is a constant true, return true.
  for (auto operand : inputConstruct.getOperands()) {
    bool b = false;
    if (matchPattern(operand, m_TorchConstantBool(&b)) && b) {
      return getI1IntegerAttr(getContext(), true);
    }
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenFloatScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) {
  // Constant fold int -> float conversion.
  if (auto integerAttr = adaptor.getA().dyn_cast_or_null()) {
    return FloatAttr::get(
        mlir::Float64Type::get(getContext()),
        static_cast(integerAttr.getValue().getSExtValue()));
  }
  // If the input is float type already, the op is an identity.
  if (getType() == getOperand().getType())
    return getOperand();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenIntFloatOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) {
  // Constant fold float -> int conversion.
  // NOTE(review): the double is converted with a plain cast; values outside
  // the destination range (or NaN/inf) are not checked here.
  if (auto floatAttr = adaptor.getA().dyn_cast_or_null()) {
    return IntegerAttr::get(
        mlir::IntegerType::get(getContext(), 64),
        static_cast(floatAttr.getValue().convertToDouble()));
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenIntScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) {
  // Constant fold float -> int conversion.
  if (auto floatAttr = adaptor.getA().dyn_cast_or_null()) {
    return IntegerAttr::get(
        mlir::IntegerType::get(getContext(), 64),
        static_cast(floatAttr.getValue().convertToDouble()));
  }
  // If the input is int type already, the op is an identity.
  if (getType() == getOperand().getType())
    return getOperand();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenIntBoolOp
//===----------------------------------------------------------------------===//

// Folds a constant bool operand to the corresponding 64-bit integer.
OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
  bool b;
  if (matchPattern(getOperand(), m_TorchConstantBool(&b))) {
    return getI64IntegerAttr(getContext(), static_cast(b));
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenSortIntOp
//===----------------------------------------------------------------------===//

void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Sorts a list of constant ints at compile time: builds a new list from the
  // sorted constants, redirects the users of the input list to it, and erases
  // the sort op.  Requires a constant `reverse` flag.
  patterns.add(+[](AtenSortIntOp op, PatternRewriter &rewriter) {
    SmallVector listElements;
    if (!matchPattern(op.getSelf(), m_TorchListOfConstantInts(listElements)))
      return rewriter.notifyMatchFailure(
          op, "all input list elements must be constant ints");
    bool reverse;
    if (!matchPattern(op.getReverse(), m_TorchConstantBool(&reverse)))
      return rewriter.notifyMatchFailure(
          op, "Expected reverse arg to be constant bool.");

    // Ascending sort first; a descending result is the reversed ascending one.
    std::sort(listElements.begin(), listElements.end());
    if (reverse)
      std::reverse(listElements.begin(), listElements.end());

    SmallVector sortedListElements;
    for (int64_t elem : listElements)
      sortedListElements.push_back(rewriter.create(
          op->getLoc(), rewriter.getI64IntegerAttr(elem)));
    Value result = rewriter.create(
        op->getLoc(), Torch::ListType::get(rewriter.getType()),
        sortedListElements);

    // NOTE(review): this rewires *every* use of the input list, and does so
    // directly on the Value rather than through `rewriter` -- confirm that
    // uses located before this op and rewriter bookkeeping are handled as
    // intended.
    op.getSelf().replaceAllUsesWith(result);
    rewriter.eraseOp(op);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// NonValueTensorLiteralOp
//===----------------------------------------------------------------------===//

// Infers the result type from the op's value attribute: a non-value tensor
// type with the attribute's shape and element type.  Fails when the attribute
// is absent or not of the expected kind.
LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
    MLIRContext *context, std::optional location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl &inferredReturnTypes) {
  auto attr = properties.as()
                  ->getValue()
                  .dyn_cast_or_null();
  if (!attr)
    return failure();
  RankedTensorType tensorType = attr.getType().cast();
  NonValueTensorType returnType =
      NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
                              tensorType.getElementType());
  inferredReturnTypes.push_back(returnType);
  return success();
}

// Two tensor types are compatible unless both carry sizes that
// `verifyCompatibleShape` rejects, or both carry dtypes that differ.  Missing
// sizes or dtype on either side never cause a mismatch.
static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) {
  if (a.hasSizes() && b.hasSizes()) {
    if (failed(verifyCompatibleShape(makeShapeLLVMCompatible(a.getSizes()),
                                     makeShapeLLVMCompatible(b.getSizes()))))
      return false;
  }
  if (a.hasDtype() && b.hasDtype()) {
    if (a.getDtype() != b.getDtype())
      return false;
  }
  return true;
}

// The declared result type may differ from the inferred one as long as it has
// the expected tensor kind and compatible sizes/dtype.
bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred,
                                                      TypeRange actual) {
  if (!actual[0].isa())
    return false;
  return areSizesAndDtypesCompatible(inferred[0].cast(),
                                     actual[0].cast());
}

//===----------------------------------------------------------------------===//
// ValueTensorLiteralOp
//===----------------------------------------------------------------------===//

// Infers the result type from the op's value attribute: a value tensor type
// with the attribute's shape and element type.  Fails when the attribute is
// absent or not of the expected kind.
LogicalResult ValueTensorLiteralOp::inferReturnTypes(
    MLIRContext *context, std::optional location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl &inferredReturnTypes) {
  auto attr = properties.as()
                  ->getValue()
                  .dyn_cast_or_null();
  if (!attr)
    return failure();
  RankedTensorType tensorType = attr.getType().cast();
  ValueTensorType returnType =
      ValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
                           tensorType.getElementType());
  inferredReturnTypes.push_back(returnType);
  return success();
}

// Folds to the op's value attribute.
OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) {
  return getValueAttr();
}

//----------------------------------------------------------------------------//
// TensorStaticInfoCast
//----------------------------------------------------------------------------//

// A cast is legal when input and output have compatible sizes and dtype.
bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs,
                                               mlir::TypeRange outputs) {
  return areSizesAndDtypesCompatible(inputs[0].cast(),
                                     outputs[0].cast());
}

void TensorStaticInfoCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Cancel a cast whose operand comes from a defining op with an operand of
  // exactly this cast's result type (a round trip).
  patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) {
    auto reverseCast =
        op.getOperand().getDefiningOp();
    if (!reverseCast || reverseCast.getOperand().getType() != op.getType())
      return failure();

    rewriter.replaceOp(op, reverseCast.getOperand());
    return success();
  });
  // When the operand type is a valid subtype of the result type, let the
  // users selected by the trait filter consume the operand directly.
  patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) {
    if (isValidSubtype(op.getOperand().getType(), op.getType())) {
      // Collect first: rewriting operands while walking the use list would
      // modify the list being iterated.
      SmallVector> usesToChange(
          llvm::make_filter_range(op->getUses(), [](OpOperand &operand) {
            return operand.getOwner()
                ->hasTrait();
          }));

      if (usesToChange.empty())
        return failure();

      // NOTE(review): operands are updated directly on the user ops, not via
      // `rewriter` -- confirm the driver does not need to be notified.
      for (OpOperand &use : usesToChange) {
        Operation *user = use.getOwner();
        user->setOperand(use.getOperandNumber(), op.getOperand());
      }

      return success();
    }
    return failure();
  });
}

//===----------------------------------------------------------------------===//
// CopyToNonValueTensorOp
//===----------------------------------------------------------------------===//
LogicalResult CopyToNonValueTensorOp::verify() { auto resultType = getResult().getType().cast(); auto operandType = getOperand().getType().cast(); if (!resultType.hasSameSizesAndDtype(operandType)) return emitError() << "operand and result must have same sizes and dtype"; return success(); } LogicalResult CopyToNonValueTensorOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { auto resultType = operands[0].getType().cast(); inferredReturnTypes.push_back(resultType.getWithoutValueSemantics()); return success(); } void CopyToNonValueTensorOp::getEffects( SmallVectorImpl> &effects) { effects.emplace_back(MemoryEffects::Allocate::get(), getResult()); } //===----------------------------------------------------------------------===// // CopyToValueTensorOp //===----------------------------------------------------------------------===// LogicalResult CopyToValueTensorOp::verify() { auto resultType = getResult().getType().cast(); auto operandType = getOperand().getType().cast(); if (!resultType.hasSameSizesAndDtype(operandType)) return emitError() << "operand and result must have same sizes and dtype"; return success(); } LogicalResult CopyToValueTensorOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { auto resultType = operands[0].getType().cast(); inferredReturnTypes.push_back(resultType.getWithValueSemantics()); return success(); } void CopyToValueTensorOp::getEffects( SmallVectorImpl> &effects) { effects.emplace_back(MemoryEffects::Read::get(), getOperand()); } //===----------------------------------------------------------------------===// // ConstantNoneOp //===----------------------------------------------------------------------===// OpFoldResult 
ConstantNoneOp::fold(FoldAdaptor adaptor) { return TypeAttr::get(Torch::NoneType::get(getContext())); } void ConstantNoneOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), "none"); } //===----------------------------------------------------------------------===// // ConstantStrOp //===----------------------------------------------------------------------===// OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void ConstantStrOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), "str"); } //===----------------------------------------------------------------------===// // ConstantDeviceOp //===----------------------------------------------------------------------===// void ConstantDeviceOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), getValue()); } //===----------------------------------------------------------------------===// // ConstantIntOp //===----------------------------------------------------------------------===// ParseResult ConstantIntOp::parse(OpAsmParser &parser, OperationState &result) { Builder builder(result.getContext()); result.addTypes(builder.getType()); if (parser.parseOptionalAttrDict(result.attributes)) return failure(); int64_t value; if (parser.parseInteger(value)) return failure(); result.addAttribute("value", builder.getI64IntegerAttr(value)); return success(); } void ConstantIntOp::print(OpAsmPrinter &p) { p << " "; p << getValueAttr().getInt(); p.printOptionalAttrDict((*this)->getAttrs(), {"value"}); } OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void Torch::ConstantIntOp::getAsmResultNames( function_ref setNameFn) { SmallVector buf; llvm::raw_svector_ostream os(buf); os << "int" << getValueAttr().getInt(); setNameFn(getResult(), os.str()); } //===----------------------------------------------------------------------===// // ConstantFloatOp 
//===----------------------------------------------------------------------===//

// Folds to the op's value attribute.
OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) {
  return getValueAttr();
}

// Names the result "float" followed by a filtered rendering of the value.
void Torch::ConstantFloatOp::getAsmResultNames(
    function_ref setNameFn) {
  // Calculate a stringified version of the number, compatible with MLIR
  // identifier syntax. (in practice, this just removes the '+' from 'e+' in
  // float string representation).
  SmallVector buf;
  getValue().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
                      /*TruncateZero=*/false);
  auto isValidMLIRIdentifierChar = [](char c) {
    return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' ||
           c == '-';
  };
  // Materialize the filtered characters before `buf` is cleared and reused.
  auto numberStr = llvm::to_vector<16>(
      llvm::make_filter_range(buf, isValidMLIRIdentifierChar));

  // Construct the identifier string.
  buf.clear();
  llvm::append_range(buf, StringRef("float"));
  llvm::append_range(buf, numberStr);
  setNameFn(getResult(), StringRef(buf.data(), buf.size()));
}

//===----------------------------------------------------------------------===//
// ConstantNumberOp
//===----------------------------------------------------------------------===//

// Folds to the op's value attribute.
OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) {
  return getValueAttr();
}

void Torch::ConstantNumberOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Rebuilds the constant from its value attribute (one op kind for float
  // attributes, another for integer attributes) and replaces this op with an
  // op of the original result type wrapping it.  Any other attribute kind is
  // left alone.
  patterns.add(+[](Torch::ConstantNumberOp op, PatternRewriter &rewriter) {
    Location loc = op->getLoc();

    Value constValue;
    Attribute value = op.getValueAttr();
    if (auto floatValue = value.dyn_cast()) {
      constValue = rewriter.create(loc, floatValue);
    } else if (auto intValue = value.dyn_cast()) {
      constValue = rewriter.create(loc, intValue);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp(op, op.getType(), constValue);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// ConstantBoolOp
//===----------------------------------------------------------------------===//

// Folds to the op's value attribute.
OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) {
  return getValueAttr();
}

// Names the result "true" or "false" after the constant's value.
void Torch::ConstantBoolOp::getAsmResultNames(
    function_ref setNameFn) {
  setNameFn(getResult(), getValue() ? "true" : "false");
}

//===----------------------------------------------------------------------===//
// PrimUncheckedCastOp
//===----------------------------------------------------------------------===//

// The cast is legal when the output type is a valid subtype of the input
// type.
bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs,
                                            mlir::TypeRange outputs) {
  return isValidSubtype(outputs[0], inputs[0]);
}

// Folds away a cast whose operand's defining op has an operand that already
// has this cast's result type.
OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) {
  if (auto derefineOp = getX().getDefiningOp()) {
    if (derefineOp.getOperand().getType() == getType())
      return derefineOp.getOperand();
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Aten__Getitem__TOp
//===----------------------------------------------------------------------===//

void Aten__Getitem__TOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Indexing with a constant into an unmutated list that has a defining op:
  // replace the op with the selected operand of that defining op.
  patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
    auto torchList = op.getOperand(0);
    if (isListPotentiallyMutated(torchList))
      return failure();

    auto listConstruct = torchList.getDefiningOp();
    if (!listConstruct)
      return failure();

    // Get the index, but be careful because it might be statically invalid.
    std::optional indexOpt = matchLegalConstantIndexIntoListOfSize(
        op.getOperand(1), listConstruct.getNumOperands());
    if (!indexOpt)
      return rewriter.notifyMatchFailure(op, "statically invalid index");

    rewriter.replaceOp(op, {listConstruct.getOperand(*indexOpt)});
    return success();
  });
  // Indexing into the result of a size query: replace with an op built from
  // the queried tensor and the index.
  patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) {
    auto sizeOp = op.getList().getDefiningOp();
    if (!sizeOp)
      return failure();
    // This assumes that the size doesn't change between the
    // AtenSizeOp and the Aten__Getitem__TOp.
    // `t_` is the only op I can find that changes the shape in-place. It seems
    // like otherwise we can treat the size of a tensor as having value
    // semantics. The other view-like ops don't have in-place variants --
    // they always return a new SSA value that is aliased to the input.
    // Can we have a pass to normalize the `t_` case and then elsewhere in the
    // compiler treat the size as having value semantics?
    // There's a small number of such ops, and they are marked as `inplace_view`
    // in PyTorch's `native_functions.yaml` file.
    rewriter.replaceOpWithNewOp(op, sizeOp.getSelf(), op.getIdx());
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenIsFloatingPointOp
//===----------------------------------------------------------------------===//

// Folds to a 1-bit constant when the operand is a tensor type with a known
// dtype; the result is whether that dtype passes the `isa` check below.
OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
  auto operandType = getSelf().getType().dyn_cast();
  if (!operandType)
    return nullptr;
  if (operandType.hasDtype()) {
    bool isFloatType = operandType.getDtype().isa();
    return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType);
  }
  // The dtype is unknown, so nothing can be folded.
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenAddTOp
//===----------------------------------------------------------------------===//

void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  // Concatenation of two unmutated lists that both have a defining op:
  // replace with a single op taking the operands of both, lhs first.
  patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) {
    auto lhsListConstruct = op.getA().getDefiningOp();
    if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct))
      return failure();

    auto rhsListConstruct = op.getB().getDefiningOp();
    if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct))
      return failure();

    SmallVector concatenatedList;
    for (auto a : lhsListConstruct.getOperands()) {
      concatenatedList.push_back(a);
    }
    for (auto b : rhsListConstruct.getOperands()) {
      concatenatedList.push_back(b);
    }

    rewriter.replaceOpWithNewOp(op, op.getType(),
                                concatenatedList);
    return success();
  });
}
//===----------------------------------------------------------------------===// // AtenSliceTOp //===----------------------------------------------------------------------===// void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenSliceTOp op, PatternRewriter &rewriter) { auto valueList = op.getL(); auto listConstructOp = valueList.getDefiningOp(); if (!listConstructOp || isListPotentiallyMutated(listConstructOp)) { return failure(); } SmallVector listElements = llvm::to_vector<4>(listConstructOp.getElements()); int64_t size = static_cast(listElements.size()); int64_t start; int64_t end; int64_t step; if (op.getStart().getType().isa()) { start = 0; } else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) { return failure(); } if (op.getEnd().getType().isa()) { end = listElements.size(); } else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { return failure(); } if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { return failure(); } start = start >= 0 ? start : start + size; start = start >= 0 ? start : 0; end = end >= 0 ? end : end + size; end = end < size ? end : size; SmallVector newListElements; for (int64_t i = start; i < end; i += step) { newListElements.push_back(listElements[i]); } rewriter.replaceOpWithNewOp( op, Torch::ListType::get(listElements[0].getType()), newListElements); return success(); }); } //===----------------------------------------------------------------------===// // AtenEqIntListOp //===----------------------------------------------------------------------===// OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) { auto lhsLiteral = getA().getDefiningOp(); if (!lhsLiteral) return nullptr; auto rhsLiteral = getB().getDefiningOp(); if (!rhsLiteral) return nullptr; // If the sizes don't match, then we know the lists aren't equal. 
if (lhsLiteral.getNumOperands() != rhsLiteral.getNumOperands()) return getI1IntegerAttr(getContext(), false); // If the sizes match and all corresponding list elements are the same Value, // then we know the lists are equal. // Note that we can't prove that the lists are not-equal with this method, // since two different Value's might dynamically be equal. if (llvm::all_of( llvm::zip(lhsLiteral.getOperands(), rhsLiteral.getOperands()), [](const auto &pair) { return std::get<0>(pair) == std::get<1>(pair); })) return getI1IntegerAttr(getContext(), true); return nullptr; } //===----------------------------------------------------------------------===// // PrimTupleConstructOp //===----------------------------------------------------------------------===// LogicalResult PrimTupleConstructOp::verify() { if (!(isValidSubtype( Torch::TupleType::get(getContext(), llvm::to_vector<6>(getElements().getType())), getResult().getType()))) return emitOpError( "failed to verify that contained types correspond to operand types"); return success(); } //===----------------------------------------------------------------------===// // PrimTupleIndexOp //===----------------------------------------------------------------------===// void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) { auto tupleConstruct = op.getTup().getDefiningOp(); if (!tupleConstruct) return failure(); int64_t i; if (!matchPattern(op.getI(), m_TorchConstantInt(&i))) return failure(); if (i >= (int64_t)tupleConstruct.getElements().size()) return failure(); // TODO: We should have a clear picture of whether we want to consistently // allow refinement, and where. It seems desirable to require precise // type equality for TupleConstruct / TupleIndex, but that might break // things. 
Value replacement = tupleConstruct.getElements()[i]; if (replacement.getType() != op.getType()) { if (op.getType().isa()) { replacement = rewriter.create( op.getLoc(), op.getType(), replacement); } else { return failure(); } } rewriter.replaceOp(op, replacement); return success(); }); } //===----------------------------------------------------------------------===// // PrimUninitializedOp //===----------------------------------------------------------------------===// void PrimUninitializedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) { if (!op.use_empty()) return failure(); rewriter.eraseOp(op); return success(); }); } //===----------------------------------------------------------------------===// // PrimTupleUnpackOp //===----------------------------------------------------------------------===// void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) { auto tupleConstruct = op.getTup().getDefiningOp(); if (!tupleConstruct) return failure(); llvm::SmallVector derefinedElements; // The result types may be supertypes of the tuple element types. // Ensure we maintain the exact type, with identity `derefine`s being // folded. 
for (auto [type, element] : llvm::zip(op.getResultTypes(), tupleConstruct.getElements())) { derefinedElements.push_back( rewriter.createOrFold(op.getLoc(), type, element)); } rewriter.replaceOp(op, derefinedElements); return success(); }); } //===----------------------------------------------------------------------===// // PrimListUnpackOp //===----------------------------------------------------------------------===// void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimListUnpackOp op, PatternRewriter &rewriter) { auto torchList = op.getOperand(); if (isListPotentiallyMutated(torchList)) { return failure(); } auto listConstruct = torchList.getDefiningOp(); if (!listConstruct) return failure(); rewriter.replaceOp(op, listConstruct.getElements()); return success(); }); } static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) { if (!llvm::all_of(torchDict.getUsers(), [](Operation *op) { return isa(op); })) return nullptr; return torchDict.getDefiningOp(); } //===----------------------------------------------------------------------===// // Aten__Getitem__DictStrOp //===----------------------------------------------------------------------===// OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) { auto dictConstruct = getDictConstructIfNotModified(getSelf()); if (!dictConstruct) return nullptr; auto targetKey = getKey(); for (auto i : llvm::zip(dictConstruct.getKeys(), dictConstruct.getValues())) { auto k = std::get<0>(i); if (k == targetKey) return std::get<1>(i); } return nullptr; } //===----------------------------------------------------------------------===// // Aten__Contains__StrOp //===----------------------------------------------------------------------===// OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) { auto dictConstruct = getDictConstructIfNotModified(getDict()); if (!dictConstruct) return nullptr; auto targetKey = getKey(); for (auto 
key : dictConstruct.getKeys()) { if (key == targetKey) return getI1IntegerAttr(getContext(), true); } return nullptr; } //===----------------------------------------------------------------------===// // Aten__Contains__IntListOp //===----------------------------------------------------------------------===// static bool isListConstructNotModified(Value torchList) { return llvm::all_of(torchList.getUsers(), [](Operation *op) { return isa(op); }); } OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) { auto itemConstruct = getItem(); if (!isListConstructNotModified(getL())) return nullptr; int64_t item; SmallVector list; if (!matchPattern(itemConstruct, m_TorchConstantInt(&item))) return nullptr; if (!matchPattern(getL(), m_TorchListOfConstantInts(list))) return nullptr; for (auto elem : list) { if (elem == item) return getI1IntegerAttr(getContext(), true); } return getI1IntegerAttr(getContext(), false); } using BinaryIntOperatorFn = std::function; static OpFoldResult atenBinaryIntOperatorFoldHelper(ArrayRef operands, BinaryIntOperatorFn f) { auto intLhs = operands[0].dyn_cast_or_null(); auto intRhs = operands[1].dyn_cast_or_null(); if (!intLhs || !intRhs) { return nullptr; } return IntegerAttr::get( intLhs.getType(), f(intLhs.getValue().getSExtValue(), intRhs.getValue().getSExtValue())); } using BinaryFloatOperatorFn = std::function; static OpFoldResult atenBinaryFloatOperatorFoldHelper(ArrayRef operands, BinaryFloatOperatorFn f) { double lhs, rhs; auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool { if (auto intLhs = attr.dyn_cast_or_null()) { value = static_cast(intLhs.getValue().getSExtValue()); } else if (auto floatLhs = attr.dyn_cast_or_null()) { value = floatLhs.getValue().convertToDouble(); } else { return false; } return true; }; if (!parseDoubleAttribute(operands[0], lhs) || !parseDoubleAttribute(operands[1], rhs)) { return nullptr; } return getF64FloatAttr(operands[0].getContext(), f(lhs, rhs)); } 
//===----------------------------------------------------------------------===// // AtenAliasOp //===----------------------------------------------------------------------===// OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { return getOperand(); } //===----------------------------------------------------------------------===// // AtenFloordivIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) { return std::floor(a / (double)b); }); } //===----------------------------------------------------------------------===// // AtenRemainderIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; }); } //===----------------------------------------------------------------------===// // AtenAddIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; }); } //===----------------------------------------------------------------------===// // AtenSubIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; }); } //===----------------------------------------------------------------------===// // AtenCatOp //===----------------------------------------------------------------------===// OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { auto list = getOperand(0).getDefiningOp(); if (!list || !list->hasOneUse() || list.getElements().size() 
!= 1) return nullptr; return list.getElements()[0]; } //===----------------------------------------------------------------------===// // AtenStackOp //===----------------------------------------------------------------------===// OpFoldResult AtenStackOp::fold(FoldAdaptor adaptor) { auto list = getOperand(0).getDefiningOp(); if (!list || !list->hasOneUse() || list.getElements().size() != 1) return nullptr; return list.getElements()[0]; } //===----------------------------------------------------------------------===// // AtenBroadcastToOp //===----------------------------------------------------------------------===// OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) return nullptr; if (inType.getSizes().size() != outType.getSizes().size()) return nullptr; for (size_t i = 0; i < inType.getSizes().size(); ++i) { if (inType.getSizes()[i] != outType.getSizes()[i]) return nullptr; } return getOperand(0); } //===----------------------------------------------------------------------===// // AtenBroadcastToOp //===----------------------------------------------------------------------===// OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) return nullptr; if (inType.getSizes().size() != outType.getSizes().size() || !inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr; for (size_t i = 0; i < inType.getSizes().size(); ++i) { if (inType.getSizes()[i] != outType.getSizes()[i]) return nullptr; } return getOperand(0); } //===----------------------------------------------------------------------===// // AtenSliceTensorOp //===----------------------------------------------------------------------===// 
OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { int64_t start, end, step; if (matchPattern(getStart(), m_TorchConstantInt(&start)) && matchPattern(getEnd(), m_TorchConstantInt(&end)) && matchPattern(getStep(), m_TorchConstantInt(&step)) && step == 1 && start == 0 && end == std::numeric_limits::max()) return getOperand(0); auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) return nullptr; if (inType.getSizes().size() != outType.getSizes().size() || !inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr; for (size_t i = 0; i < inType.getSizes().size(); ++i) { if (inType.getSizes()[i] != outType.getSizes()[i]) return nullptr; } return getOperand(0); } //===----------------------------------------------------------------------===// // AtenMulIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); if ((lConstant && lhs == 0) || (rConstant && rhs == 0)) return getI64IntegerAttr(getContext(), 0); if (lConstant && rConstant) return getI64IntegerAttr(getContext(), lhs * rhs); return nullptr; } //===----------------------------------------------------------------------===// // AtenMulFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenMulFloatOp::fold(FoldAdaptor adaptor) { return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) { return a * b; }); } //===----------------------------------------------------------------------===// // AtenSubFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) { return 
atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) { return a - b; }); } //===----------------------------------------------------------------------===// // AtenAddOp //===----------------------------------------------------------------------===// OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } if (adaptor.getA().isa() && adaptor.getB().isa()) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t { return a + b; }); } return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) -> double { return a + b; }); } //===----------------------------------------------------------------------===// // AtenSubOp //===----------------------------------------------------------------------===// OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } if (adaptor.getA().isa() && adaptor.getB().isa()) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t { return a - b; }); } return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) -> double { return a - b; }); } //===----------------------------------------------------------------------===// // AtenDivOp //===----------------------------------------------------------------------===// OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } // Since AtenDivOp always returns float value, we don't need to deal with the // case where the operands are both integers separately. 
return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) -> double { return a / b; }); } //===----------------------------------------------------------------------===// // AtenAddFloatIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) { return a + b; }); } //===----------------------------------------------------------------------===// // AtenPowIntFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenPowIntFloatOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) { return std::pow(a, b); }); } //===----------------------------------------------------------------------===// // AtenCeilScalarOp //===----------------------------------------------------------------------===// OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA()) { return nullptr; } auto floatValue = adaptor.getA().dyn_cast_or_null(); if (!floatValue) { return nullptr; } return getI64IntegerAttr( getContext(), static_cast(std::ceil(floatValue.getValue().convertToDouble()))); } //===----------------------------------------------------------------------===// // AtenNegIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) { int64_t c; if (matchPattern(getOperand(), m_TorchConstantInt(&c))) return getI64IntegerAttr(getContext(), -c); return nullptr; } //===----------------------------------------------------------------------===// // AtenNegFloatOp //===----------------------------------------------------------------------===// OpFoldResult 
AtenNegFloatOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA()) { return nullptr; } auto value = adaptor.getA().dyn_cast_or_null(); if (!value) { return nullptr; } return getF64FloatAttr(getContext(), -value.getValue().convertToDouble()); } //===----------------------------------------------------------------------===// // AtenSqrtIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) { int64_t c; if (matchPattern(getOperand(), m_TorchConstantInt(&c))) return getF64FloatAttr(getContext(), std::sqrt(c)); return nullptr; } //===----------------------------------------------------------------------===// // PrimDtypeOp //===----------------------------------------------------------------------===// OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) { BaseTensorType tensorType = getA().getType().cast(); if (tensorType.hasDtype()) { torch_upstream::ScalarType scalarType = Torch::getScalarTypeForType(tensorType.getDtype()); return getI64IntegerAttr(getContext(), static_cast(scalarType)); } return nullptr; } //===----------------------------------------------------------------------===// // PrimDeviceOp //===----------------------------------------------------------------------===// void PrimDeviceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimDeviceOp op, PatternRewriter &rewriter) { // Device information isn't relevant to torch-mlir, just replace it with // "cpu". 
rewriter.replaceOpWithNewOp(op, "cpu"); return success(); }); } //===----------------------------------------------------------------------===// // AtenCudaOp //===----------------------------------------------------------------------===// void AtenCudaOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenCudaOp op, PatternRewriter &rewriter) { // Device information isn't relevant to torch-mlir auto inputTensor = op.getSelf(); rewriter.replaceOp(op, inputTensor); return success(); }); } //===----------------------------------------------------------------------===// // AtenDeviceWithIndexOp //===----------------------------------------------------------------------===// void AtenDeviceWithIndexOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenDeviceWithIndexOp op, PatternRewriter &rewriter) { std::string type; int64_t index; if (!matchPattern(op.getType(), m_TorchConstantStr(type))) { return rewriter.notifyMatchFailure( op, "unimplemented: type must be a constant string"); } if (!matchPattern(op.getIndex(), m_TorchConstantInt(&index))) { return rewriter.notifyMatchFailure( op, "unimplemented: index must be a constant integer"); } rewriter.replaceOpWithNewOp( op, type + ":" + std::to_string(index)); return success(); }); } //===----------------------------------------------------------------------===// // AtenIntTensorOp //===----------------------------------------------------------------------===// OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) { // If a scalar number is converted to a 0-d tensor and passed on to // aten.Int.Tensor, fold to the scalar number. 
if (auto numToTensorScalar = getA().getDefiningOp()) return numToTensorScalar.getA(); if (auto tensorIntOp = getA().getDefiningOp()) return tensorIntOp.getT(); return nullptr; } //===----------------------------------------------------------------------===// // AtenFloatTensorOp //===----------------------------------------------------------------------===// OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) { // If a scalar number is converted to a 0-d tensor and passed on to // aten.Float.Tensor, fold to the scalar number. if (auto numToTensorScalar = getA().getDefiningOp()) return numToTensorScalar.getA(); return nullptr; } //===----------------------------------------------------------------------===// // AtenDivFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) { double lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs)); if (lConstant && lhs == 0.0) return getF64FloatAttr(getContext(), 0.0); if (lConstant && rConstant && rhs == 1.0) return getF64FloatAttr(getContext(), lhs); if (lConstant && rConstant) return getF64FloatAttr(getContext(), lhs / rhs); return nullptr; } //===----------------------------------------------------------------------===// // AtenDivIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); if (lConstant && rConstant) return getF64FloatAttr(getContext(), double(lhs) / rhs); return nullptr; } //===----------------------------------------------------------------------===// // AtenCeilFloatOp //===----------------------------------------------------------------------===// OpFoldResult 
AtenCeilFloatOp::fold(FoldAdaptor adaptor) { double c; if (matchPattern(getOperand(), m_TorchConstantFloat(&c))) return getI64IntegerAttr(getContext(), std::ceil(c)); return nullptr; } //===----------------------------------------------------------------------===// // PrimMaxIntOp //===----------------------------------------------------------------------===// OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { // If both operands are the same, then the operation is an identity. if (getA() == getB()) return getA(); auto lhs = adaptor.getA().dyn_cast_or_null(); auto rhs = adaptor.getB().dyn_cast_or_null(); if (!lhs || !rhs) return nullptr; // Torch semantics are that !torch.int is 64-bit signed. return IntegerAttr::get( lhs.getType(), std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue())); } //===----------------------------------------------------------------------===// // PrimMinSelfIntOp //===----------------------------------------------------------------------===// OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) { auto list = getOperand().getDefiningOp(); if (!list) return nullptr; // TODO: What does it return for an empty list? if (list->getNumOperands() == 0) return nullptr; SmallVector values; for (auto operand : list->getOperands()) { int64_t value; if (!matchPattern(operand, m_TorchConstantInt(&value))) return nullptr; values.push_back(value); } return getI64IntegerAttr(getContext(), *std::min_element(values.begin(), values.end())); } //===----------------------------------------------------------------------===// // PrimMinIntOp //===----------------------------------------------------------------------===// OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) { // If both operands are the same, then the operation is an identity. 
if (getA() == getB()) return getA(); auto lhs = adaptor.getA().dyn_cast_or_null(); auto rhs = adaptor.getB().dyn_cast_or_null(); if (!lhs || !rhs) return nullptr; // Torch semantics are that !torch.int is 64-bit signed. return IntegerAttr::get( lhs.getType(), std::min(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue())); } //===----------------------------------------------------------------------===// // ShapeCalculateOp //===----------------------------------------------------------------------===// template static void getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point, SmallVectorImpl ®ions) { if (!point.getRegionOrNull()) { // First thing the op does is branch into the calculation. regions.emplace_back(&op.getCalculation()); return; } if (point == op.getBody()) { // Body returns control to the outer op, passing through results. regions.emplace_back(op.getResults()); return; } assert(point == op.getCalculation()); // Calculation branches to the body. regions.emplace_back(&op.getBody()); } void ShapeCalculateOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { getSuccessorRegionsForCalculateOp(*this, point, regions); } //===----------------------------------------------------------------------===// // DtypeCalculateOp //===----------------------------------------------------------------------===// void DtypeCalculateOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { getSuccessorRegionsForCalculateOp(*this, point, regions); } //===----------------------------------------------------------------------===// // ShapeCalculateYieldShapesOp //===----------------------------------------------------------------------===// MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands( RegionBranchPoint point) { // The shape operands don't get forwarded to the body. // MutableOperandRange always has an owning operation, even if empty, so // create a 0-length range. 
return MutableOperandRange(*this, /*start=*/0, /*length=*/0); } LogicalResult ShapeCalculateYieldShapesOp::verify() { auto parent = cast(getOperation()->getParentOp()); if (parent.getNumResults() != getNumOperands()) return emitOpError("expected number of shapes to match number of results"); return success(); } //===----------------------------------------------------------------------===// // DtypeCalculateYieldDtypesOp //===----------------------------------------------------------------------===// MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands( RegionBranchPoint point) { // The dtype operands don't get forwarded to the body. // MutableOperandRange always has an owning operation, even if empty, so // create a 0-length range. return MutableOperandRange(*this, /*start=*/0, /*length=*/0); } LogicalResult DtypeCalculateYieldDtypesOp::verify() { auto parent = cast(getOperation()->getParentOp()); if (parent.getNumResults() != getNumOperands()) return emitOpError("expected number of dtypes to match number of results"); return success(); } //===----------------------------------------------------------------------===// // GlobalSlotModuleInitializerOp //===----------------------------------------------------------------------===// LogicalResult GlobalSlotModuleInitializerOp::verify() { // We centralize all verification of the global slots and the // InitializeGlobalSlotsOp into here, since it requires processing the whole // module. // TODO: We should really have a `torch.module` and have this initializer be // a region attached to it. ModuleOp module = cast(getOperation()->getParentOp()); for (auto op : module.getOps()) { if (op.getOperation() != getOperation()) return op.emitError("there must be only one global slot initializer"); } // Collect the relevant symbol names we will verify. 
DenseSet knownGlobalSlots; for (auto op : module.getOps()) knownGlobalSlots.insert(op.getSymNameAttr()); DenseSet initializedGlobalSlots; auto initialize = cast(getBody()->getTerminator()); for (Attribute symName : initialize.getSlotSymNames()) { auto wasInserted = initializedGlobalSlots .insert(symName.cast().getAttr()) .second; if (!wasInserted) return initialize.emitError("duplicate initialization of global slot: ") << symName; } auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) { return lhs.cast().getValue() < rhs.cast().getValue(); }; auto known = llvm::to_vector(knownGlobalSlots); llvm::sort(known, lessThanByStringValue); auto initialized = llvm::to_vector(initializedGlobalSlots); llvm::sort(initialized, lessThanByStringValue); // Check that the global slots in the module are all initialized. SymbolTable symbolTable(module); if (initializedGlobalSlots != knownGlobalSlots) { InFlightDiagnostic diag = initialize.emitOpError( "must have one initializer for each global slot in the module"); for (auto knownGlobalSlot : known) { auto symName = FlatSymbolRefAttr::get(knownGlobalSlot.cast()); if (!initializedGlobalSlots.count(knownGlobalSlot)) { diag.attachNote( symbolTable.lookup(symName.getAttr()).getLoc()) .append("missing global slot initializer for ", symName); } } for (auto initializedGlobalSlot : initialized) { if (!knownGlobalSlots.count(initializedGlobalSlot)) { diag.attachNote().append( "unexpected global slot initializer for non-existent global slot ", FlatSymbolRefAttr::get(initializedGlobalSlot.cast())); } } return diag; } // Check that initial values satisfy type bounds. 
for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) { auto symName = initialize.getSlotSymNames()[i].cast(); auto initialValue = initialize.getOperand(i); auto globalSlotOp = symbolTable.lookup(symName.getValue()); if (!isValidSubtype(initialValue.getType(), globalSlotOp.getTypeBound())) { return initialize.emitOpError().append( "initial value for global slot ", symName, " has type ", initialValue.getType(), " which is not within the bound ", globalSlotOp.getTypeBound()); } } auto walkResult = getOperation()->walk([](Operation *op) { // We only permit a small set of ops in the module initializer. // These ops are essentially those which can be produced by the IValue // importer. if (op->hasTrait()) return WalkResult::advance(); op->emitOpError() << "is not allowed in a module initializer"; return WalkResult::interrupt(); }); if (walkResult.wasInterrupted()) return failure(); return success(); } //===----------------------------------------------------------------------===// // InitializeGlobalSlotsOp //===----------------------------------------------------------------------===// ParseResult InitializeGlobalSlotsOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); if (parser.parseLSquare()) return failure(); SmallVector slotSymNames; while (!succeeded(parser.parseOptionalRSquare())) { NamedAttrList dummy; StringAttr slotSymName; if (parser.parseSymbolName(slotSymName, "dummy", dummy)) return failure(); slotSymNames.push_back(FlatSymbolRefAttr::get(slotSymName)); if (parser.parseLParen()) return failure(); OpAsmParser::UnresolvedOperand initialValue; if (parser.parseOperand(initialValue)) return failure(); Type initialValueType; if (parser.parseColonType(initialValueType)) return failure(); if (parser.parseRParen()) return failure(); if (parser.resolveOperand(initialValue, initialValueType, result.operands)) return failure(); } result.addAttribute("slotSymNames", 
ArrayAttr::get(parser.getContext(), slotSymNames)); return success(); } void InitializeGlobalSlotsOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{"slotSymNames"}); p << " ["; p.printNewline(); for (int i = 0, e = getNumOperands(); i < e; ++i) { p << " " << getSlotSymNames()[i] << "(" << getInitialValues()[i] << " : " << getInitialValues()[i].getType() << ")"; p.printNewline(); } p << "]"; } LogicalResult InitializeGlobalSlotsOp::verify() { if (getInitialValues().size() != getSlotSymNames().size()) return emitOpError("expected number of operands to match number of slots"); return success(); }