//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #define DEBUG_TYPE "torch-mlir-torch-dialect" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Casting.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder, Location loc, Value value, Type desiredType, bool userAllowsRefinement) { Type type = value.getType(); // If the value is already of the desired type, we're done. if (type == desiredType) return value; // If the type is a tensor, then adjust the static information. if ((type.isa() && desiredType.isa()) || (type.isa() && desiredType.isa())) { Value adjusted = builder.create(value.getLoc(), desiredType, value); return adjusted; } // If the type is a subtype of desiredType, then we need to derefine it to // desiredType, unless the user allows refinement. 
if (isValidSubtype(type, desiredType)) { if (!userAllowsRefinement) { Value adjusted = builder.create(value.getLoc(), desiredType, value); return adjusted; } else { return value; } } // If the desiredType is subtype of type, then we assume that the desiredType // is dynamically valid, so we do an unchecked cast. if (isValidSubtype(desiredType, type)) { Value adjusted = builder.create(value.getLoc(), desiredType, value); return adjusted; } // No known adjustment. return Value(); } Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc, BaseTensorType newType, Value tensor) { auto originalType = tensor.getType().cast(); // Adjust the static information in the type to match between the original and // new types. if (!originalType.hasSameSizesAndDtype(newType)) { tensor = builder.create( loc, originalType.getWithSizesAndDtypeFrom(newType), tensor); } // Unless both the original and new types are both value tensors, we end // up creating one op that converts between the value and non-value tensor // domains. If both the original and new types are both non-value tensors, // then we do the copy by going to a value tensor and back. if (tensor.getType().isa()) tensor = builder.create(loc, tensor); if (newType.isa()) tensor = builder.create(loc, tensor); return tensor; } bool mlir::torch::Torch::isListPotentiallyMutated(Value list) { assert(list.getType().isa()); return llvm::any_of(list.getUsers(), potentiallyMutatesListOperands); } bool mlir::torch::Torch::potentiallyMutatesListOperands(Operation *op) { // TODO: Find a better place to put this assertion. assert((!op->hasTrait() || op->hasTrait()) && "HasValueSemantics should imply ReadOnly!"); // ReadOnly ops trivially do not mutate any list operands. if (op->hasTrait()) return false; // Ops with no MemoryEffectOpInterface effects also do not mutate any list // operands. 
if (auto effects = dyn_cast(op)) { if (effects.hasNoEffect()) return false; } // Conservatively assume that an op might mutate any list operands. return true; } static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) { return IntegerAttr::get(IntegerType::get(context, 64), value); } static FloatAttr getF64FloatAttr(MLIRContext *context, double value) { return FloatAttr::get(Float64Type::get(context), value); } static Value getScalarIntValue(Value input, Location loc, PatternRewriter &rewriter) { auto inputType = input.getType(); if (inputType.isa()) { return input; } auto inputTensorType = inputType.dyn_cast(); if (!inputTensorType) return nullptr; Type inputDtype = inputTensorType.getOptionalDtype(); if (!inputDtype || !inputDtype.isInteger(64)) return nullptr; std::optional inputRank = getTensorRank(input); if (!inputRank || *inputRank != 0) return nullptr; if (auto valueTensorLiteralOp = input.getDefiningOp()) { auto val = valueTensorLiteralOp.getValue() .cast() .getSplatValue(); return rewriter.create( loc, rewriter.getI64IntegerAttr(val)); } else if (auto primNumToTensorScalarOp = input.getDefiningOp()) { return primNumToTensorScalarOp.getA(); } else if (auto tensorIntOp = input.getDefiningOp()) { return tensorIntOp.getT(); } return nullptr; } static Value getScalarFloatValue(Value input, Location loc, PatternRewriter &rewriter) { auto inputType = input.getType(); if (inputType.isa()) { return input; } auto inputTensorType = inputType.dyn_cast(); if (!inputTensorType) return nullptr; Type inputDtype = inputTensorType.getOptionalDtype(); if (!inputDtype || (!inputDtype.isF16() && !inputDtype.isF32() && !inputDtype.isF64())) return nullptr; std::optional inputRank = getTensorRank(input); if (!inputRank || *inputRank != 0) return nullptr; if (auto valueTensorLiteralOp = input.getDefiningOp()) { auto val = valueTensorLiteralOp.getValue() .cast() .getSplatValue() .getValueAsDouble(); return rewriter.create( loc, rewriter.getF64FloatAttr(val)); } 
else if (auto primNumToTensorScalarOp = input.getDefiningOp()) { return primNumToTensorScalarOp.getA(); } else if (auto tensorFloatOp = input.getDefiningOp()) { return tensorFloatOp.getT(); } return nullptr; } //===----------------------------------------------------------------------===// // MethodOp //===----------------------------------------------------------------------===// LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto func = symbolTable.lookupNearestSymbolFrom( *this, getFunctionAttr()); if (!func) return emitError() << "'@" << getFunction() << "' does not reference a valid function"; if (func.getVisibility() != SymbolTable::Visibility::Private) return emitError() << "'@" << getFunction() << "' must reference a private function"; if (func.isDeclaration()) return emitError() << "'@" << getFunction() << "' must reference a function that is defined (not " "merely declared)"; auto expectedReceiverArgType = NnModuleType::get( getContext(), getOperation()->getParentOfType().getName()); if (func.getFunctionType().getNumInputs() == 0 || func.getFunctionType().getInput(0) != expectedReceiverArgType) { return emitError() << "the referenced function '" << getFunction() << "' must have a first argument of type " << expectedReceiverArgType; } return success(); } //===----------------------------------------------------------------------===// // NnModuleOp //===----------------------------------------------------------------------===// LogicalResult NnModuleOp::verify() { for (Operation &child : *getBody()) if (!isa(&child)) return child.emitOpError() << "is not allowed inside 'torch.nn_module'"; return success(); } LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto classType = symbolTable.lookupNearestSymbolFrom( *this, SymbolRefAttr::get(getContext(), getClassName())); if (!classType) return emitError() << "'" << getClassName() << "' does not reference a valid class type"; auto attrs = 
llvm::to_vector<6>(getBody()->getOps()); auto attrDefs = llvm::to_vector<6>(classType.getBody()->getOps()); if (attrs.size() != attrDefs.size()) return emitError() << "number of 'torch.slot's in a 'torch.nn_module' must " "match number of 'torch.attr's in " "the corresponding 'torch.class_type'"; for (int i = 0, e = attrs.size(); i != e; i++) { SlotOp attr = attrs[i]; AttrOp attrDef = attrDefs[i]; if (!isValidSubtype(attr.getValue().getType(), attrDef.getType()) || attr.getName() != attrDef.getName()) { return attr.emitOpError() .append("is expected to match type and name of '", attrDef.getOperation(), "'") .attachNote(attrDef.getLoc()) .append("see torch.attr at corresponding index ", i, " here"); } } return success(); } //===----------------------------------------------------------------------===// // PrimListConstructOp //===----------------------------------------------------------------------===// LogicalResult PrimListConstructOp::verify() { auto resultType = getResult().getType(); auto resultElementType = resultType.dyn_cast().getContainedType(); auto matchResultElementType = [&](Type type) { return isValidSubtype(type, resultElementType); }; if (!llvm::all_of(getOperandTypes(), matchResultElementType)) { return emitError() << "operand types should have the same type as the " "list contained type"; } return success(); } //===----------------------------------------------------------------------===// // PrimDictConstructOp //===----------------------------------------------------------------------===// LogicalResult PrimDictConstructOp::verify() { auto isValidSubTypeOf = [](Type expectedType) { return [=](Type type) { return isValidSubtype(type, expectedType); }; }; if (!llvm::all_of(getKeys().getTypes(), isValidSubTypeOf(getKeyType()))) return emitError() << "keys should be of Dict key type"; if (!llvm::all_of(getValues().getTypes(), isValidSubTypeOf(getValueType()))) return emitError() << "values should be of Dict value type"; return success(); } 
//===----------------------------------------------------------------------===// // ClassTypeOp //===----------------------------------------------------------------------===// LogicalResult ClassTypeOp::verify() { llvm::StringMap namesToOps; for (Operation &child : getBody()->without_terminator()) { if (!isa(&child)) return child.emitOpError() << "is not allowed inside `torch.class_type`"; StringRef name; if (auto attr = dyn_cast(child)) name = attr.getName(); else name = cast(child).getName(); auto itAndWasInserted = namesToOps.insert({name, &child}); auto it = itAndWasInserted.first; bool wasInserted = itAndWasInserted.second; if (!wasInserted) { auto diag = emitOpError().append("has duplicate attr/method with name '", name, "'"); diag.attachNote(it->second->getLoc()) .append("see first conflicting attr/method here"); diag.attachNote(child.getLoc()) .append("see second conflicting attr/method here"); return failure(); } } return success(); } //===----------------------------------------------------------------------===// // PrimLoopOp //===----------------------------------------------------------------------===// OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) { assert(point == getRegion()); return getIterArgsInit(); } void PrimLoopOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { Region ®ion = getRegion(); if (!point.getRegionOrNull()) { regions.emplace_back(®ion, region.getArguments().slice(1)); return; } assert(point == region); regions.emplace_back(®ion, region.getArguments().slice(1)); regions.emplace_back(getResults()); } bool PrimLoopOp::isForLike() { bool b; return matchPattern(getInitialCondition(), m_TorchConstantBool(&b)) && b; } //===----------------------------------------------------------------------===// // PrimLoopConditionOp //===----------------------------------------------------------------------===// MutableOperandRange PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint 
point) { // Pass all operands except the condition to the successor which is the // parent loop op. return getIterArgsMutable(); } //===----------------------------------------------------------------------===// // PrimIfOp //===----------------------------------------------------------------------===// ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) { // Create the regions. result.regions.reserve(2); Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); auto &builder = parser.getBuilder(); OpAsmParser::UnresolvedOperand cond; Type boolType = builder.getType(); if (parser.parseOperand(cond) || parser.resolveOperand(cond, boolType, result.operands)) return failure(); // Parse results type list. if (parser.parseArrowTypeList(result.types)) return failure(); // Parse the 'then' region. if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); // Parse the 'else' region. if (parser.parseKeyword("else")) return failure(); if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); // Parse the optional attribute list. if (parser.parseOptionalAttrDict(result.attributes)) return failure(); return success(); } void PrimIfOp::print(OpAsmPrinter &p) { p << " " << getCondition(); p << " -> (" << getResultTypes() << ") "; p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false); p << " else "; p.printRegion(getElseRegion(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict((*this)->getAttrs()); } void PrimIfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. if (point.getRegionOrNull()) { regions.push_back(RegionSuccessor(getResults())); return; } // If the condition is constant, we can give a more precise answer. bool condition; if (matchPattern(getCondition(), m_TorchConstantBool(&condition))) { Region *executedRegion = condition ? 
&getThenRegion() : &getElseRegion(); regions.push_back(RegionSuccessor(executedRegion)); return; } // If the condition isn't constant, both regions may be executed. regions.push_back(RegionSuccessor(&getThenRegion())); regions.push_back(RegionSuccessor(&getElseRegion())); return; } /// Replaces the given op with the contents of the given single-block region, /// using the operands of the block terminator to replace operation results. static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs = {}) { assert(llvm::hasSingleElement(region) && "expected single-region block"); Block *block = ®ion.front(); Operation *terminator = block->getTerminator(); ValueRange results = terminator->getOperands(); rewriter.inlineBlockBefore(block, op, blockArgs); rewriter.replaceOp(op, results); rewriter.eraseOp(terminator); } void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // If the condition is constant, delete the dead branch and inline the live // branch. patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) { auto constantBool = op.getCondition().getDefiningOp(); if (!constantBool) return rewriter.notifyMatchFailure(op, "non-constant condition"); replaceOpWithRegion(rewriter, op, constantBool.getValue() ? op.getThenRegion() : op.getElseRegion()); return success(); }); // If the thenRegion and elseRegion yield the same Value's, then use those // directly. 
patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) { auto trueTerminator = op.getThenRegion().front().getTerminator(); auto falseTerminator = op.getElseRegion().front().getTerminator(); bool madeChange = false; SmallVector resultsToErase; for (auto t : llvm::zip(trueTerminator->getOperands(), falseTerminator->getOperands(), op->getResults())) { auto trueVal = std::get<0>(t); auto falseVal = std::get<1>(t); auto resultToBeReplaced = std::get<2>(t); if (trueVal == falseVal) { madeChange |= !resultToBeReplaced.use_empty(); resultToBeReplaced.replaceAllUsesWith(trueVal); } } // We leave it up to a separate pattern (not yet implemented) to erase the // results that are now dead. That transformation is independently useful, // and also pretty tricky to implement because it changes the number of // results. return success(madeChange); }); // Erase any dead results. patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) { llvm::BitVector resultsToErase(op.getNumResults()); for (auto result : llvm::enumerate(op->getResults())) { if (result.value().use_empty()) resultsToErase.set(result.index()); } // If no results have uses and there are no side effects, just erase the op. // Approximate the body having no side effects by checking if it is just a // terminator. // Note: We don't want to make this logic too fancy, because in general, // checking for recursive side effects can result in a quadratic amount of // work (N nested If's each resulting in O(N) work). It should probably be // split into its own pattern if we want to make it fancier. if (resultsToErase.all() && llvm::hasSingleElement(op.getThenRegion().front()) && llvm::hasSingleElement(op.getElseRegion().front())) { rewriter.eraseOp(op); return success(); } // If there are no results to erase, we're done. 
if (!resultsToErase.any()) return failure(); SmallVector newResultTypes; for (int i = 0, e = op->getNumResults(); i < e; ++i) { if (resultsToErase[i]) continue; newResultTypes.push_back(op->getResult(i).getType()); } auto newIf = rewriter.create(op->getLoc(), newResultTypes, op.getCondition()); rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), newIf.getThenRegion().end()); rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), newIf.getElseRegion().end()); newIf.getThenRegion().front().getTerminator()->eraseOperands( resultsToErase); newIf.getElseRegion().front().getTerminator()->eraseOperands( resultsToErase); SmallVector replacementValues; for (int i = 0, e = op->getNumResults(), nextNewValue = 0; i < e; ++i) { if (resultsToErase[i]) replacementValues.push_back(nullptr); else replacementValues.push_back(newIf->getResult(nextNewValue++)); } rewriter.replaceOp(op, replacementValues); return success(); }); } //===----------------------------------------------------------------------===// // RuntimeAssertOp //===----------------------------------------------------------------------===// void RuntimeAssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](RuntimeAssertOp op, PatternRewriter &rewriter) { bool value; if (!matchPattern(op.getCondition(), m_TorchConstantBool(&value))) return failure(); if (value) { rewriter.eraseOp(op); return success(); } // Even if the condition is statically false, the assert might never be // executed. 
return failure(); }); } //===----------------------------------------------------------------------===// // DerefineOp //===----------------------------------------------------------------------===// bool DerefineOp::areCastCompatible(mlir::TypeRange inputs, mlir::TypeRange outputs) { return isValidSubtype(inputs[0], outputs[0]); } OpFoldResult DerefineOp::fold(FoldAdaptor adaptor) { auto uncheckedCast = getOperand().getDefiningOp(); if (!uncheckedCast) return nullptr; if (uncheckedCast.getOperand().getType() == getType()) return uncheckedCast.getOperand(); return nullptr; } void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) { bool madeChange = false; for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) { if (use.getOwner()->hasTrait()) { use.set(op.getOperand()); madeChange = true; } } return success(madeChange); }); } static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) { Value lhs = op->getOperand(0); Value rhs = op->getOperand(1); // Look through DerefineOp's to get more refined static information. if (auto derefine = lhs.getDefiningOp()) lhs = derefine.getOperand(); if (auto derefine = rhs.getDefiningOp()) rhs = derefine.getOperand(); Type lhsType = lhs.getType(); Type rhsType = rhs.getType(); // If either type is a NoneType, make it be the lhsType. if (rhsType.isa()) { std::swap(lhsType, rhsType); std::swap(lhs, rhs); } // For now, check a few specific cases. // If both types are the singleton `!torch.none` type, then we don't even need // to look at the values. if (lhsType.isa() && rhsType.isa()) return IntegerAttr::get(IntegerType::get(op->getContext(), 1), equalIsTrue); // If neither type is a subtype of the other, then the result is false. // TODO: Implement and use subtype infra for this. // For now, check a specific case. // If the rhs is not OptionalType, then we know it cannot be None. 
if (lhsType.isa() && !rhsType.isa()) { return IntegerAttr::get(IntegerType::get(op->getContext(), 1), !equalIsTrue); } return nullptr; } //===----------------------------------------------------------------------===// // Aten__RangeLengthOp //===----------------------------------------------------------------------===// OpFoldResult Aten__RangeLengthOp::fold(FoldAdaptor adaptor) { auto lo = adaptor.getLo(); auto hi = adaptor.getHi(); auto step = adaptor.getStep(); if (!lo || !hi || !step) return nullptr; auto loInt = lo.dyn_cast_or_null().getValue(); auto hiInt = hi.dyn_cast_or_null().getValue(); auto stepInt = step.dyn_cast_or_null().getValue(); // TODO: Implement folding for negative steps. if (stepInt.isNegative()) return nullptr; // From Python language spec: // r[i] = lo + step*i such that i >= 0 and r[i] < hi // So maximize `i` such that lo + step * i < hi // ==> i == ceildiv(hi - lo, step) return IntegerAttr::get(lo.cast().getType(), llvm::APIntOps::RoundingSDiv(hiInt - loInt, stepInt, APInt::Rounding::UP)); } //===----------------------------------------------------------------------===// // Aten__DeriveIndexOp //===----------------------------------------------------------------------===// OpFoldResult Aten__DeriveIndexOp::fold(FoldAdaptor adaptor) { auto index = adaptor.getIndex(); auto start = adaptor.getStart(); auto step = adaptor.getStep(); if (!index || !start || !step) return nullptr; auto indexInt = index.dyn_cast_or_null().getValue(); auto startInt = start.dyn_cast_or_null().getValue(); auto stepInt = step.dyn_cast_or_null().getValue(); return IntegerAttr::get(index.cast().getType(), startInt + stepInt * indexInt); } //===----------------------------------------------------------------------===// // Aten__Is__Op //===----------------------------------------------------------------------===// OpFoldResult Aten__Is__Op::fold(FoldAdaptor adaptor) { return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/true); } 
//===----------------------------------------------------------------------===// // Aten__Isnot__Op //===----------------------------------------------------------------------===// OpFoldResult Aten__Isnot__Op::fold(FoldAdaptor adaptor) { return atenIsOrIsNotFoldHelper(*this, /*equalIsTrue=*/false); } //===----------------------------------------------------------------------===// // Aten__Not__Op //===----------------------------------------------------------------------===// OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) { bool value; if (!matchPattern(getOperand(), m_TorchConstantBool(&value))) return nullptr; return IntegerAttr::get(IntegerType::get(getContext(), 1), !value); } //===----------------------------------------------------------------------===// // AtenNeBoolOp //===----------------------------------------------------------------------===// OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) { if (getOperand(0) == getOperand(1)) return IntegerAttr::get(IntegerType::get(getContext(), 1), false); bool a, b; if (!matchPattern(getOperand(0), m_TorchConstantBool(&a))) return nullptr; if (!matchPattern(getOperand(1), m_TorchConstantBool(&b))) return nullptr; return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b); } //===----------------------------------------------------------------------===// // AtenUnsqueezeOp //===----------------------------------------------------------------------===// OpFoldResult AtenUnsqueezeOp::fold(FoldAdaptor adaptor) { auto selfTy = dyn_cast(getSelf().getType()); auto rty = dyn_cast(getType()); if (!rty.hasDtype()) return {}; if (auto attr = dyn_cast_or_null(adaptor.getSelf())) { auto aty = dyn_cast(attr.getType()); if (rty.hasSizes() && rty.areAllSizesKnown() && attr.isSplat()) { auto naty = RankedTensorType::get(rty.getSizes(), aty.getElementType()); return DenseElementsAttr::get(naty, attr.getSplatValue()); } } if (getSelf().getType() != getResult().getType()) return nullptr; if (selfTy && rty) { if 
(selfTy.hasSizes() && rty.hasSizes() && selfTy.getSizes().size() == rty.getSizes().size()) return getSelf(); } return nullptr; } //===----------------------------------------------------------------------===// // AtenSqueezeOp //===----------------------------------------------------------------------===// OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { auto selfTy = dyn_cast(getSelf().getType()); auto rty = dyn_cast(getType()); if (!rty.hasDtype()) return {}; if (auto attr = dyn_cast_or_null(adaptor.getSelf())) { auto aty = dyn_cast(attr.getType()); if (rty.hasSizes() && rty.areAllSizesKnown() && attr.isSplat()) { auto naty = RankedTensorType::get(rty.getSizes(), aty.getElementType()); return DenseElementsAttr::get(naty, attr.getSplatValue()); } } if (getSelf().getType() != getResult().getType()) return nullptr; if (selfTy && rty) { if (selfTy.hasSizes() && rty.hasSizes() && selfTy.getSizes().size() == rty.getSizes().size()) return getSelf(); } return nullptr; } //===----------------------------------------------------------------------===// // AtenSqueezeDimOp //===----------------------------------------------------------------------===// OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { if (getOperand(0).getType() != getResult().getType()) return nullptr; if (auto tensorType = getOperand(0).getType().dyn_cast()) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(0); } return nullptr; } //===----------------------------------------------------------------------===// // AtenToDtypeOp //===----------------------------------------------------------------------===// OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) { bool nonBlocking, copyArg; // The non_blocking arg must be `False`. if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) || nonBlocking) return nullptr; // The copy arg must be `False`. 
if (!matchPattern(getCopy(), m_TorchConstantBool(©Arg)) || copyArg) return nullptr; // The memory_format arg must be `none`. if (!getMemoryFormat().getType().isa()) return nullptr; auto inputType = getSelf().getType().cast(); auto resType = getType().cast(); // If the types aren't equal, then we can't fold. if (inputType != resType) return nullptr; // If the type does not have a statically known dtype, then we cannot fold. // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong, // since the `unk` could be dynamically different for the operand and result. if (!inputType.hasDtype()) return nullptr; // Fold when both the input tensor and result are of the same type. return getOperand(0); } //===----------------------------------------------------------------------===// // AtenToDtypeLayoutOp //===----------------------------------------------------------------------===// OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) { // The pin_memory arg should be either constant `False` or `none`. if (!getPinMemory().getType().isa()) { bool pinMemory; if (!matchPattern(getPinMemory(), m_TorchConstantBool(&pinMemory))) return nullptr; else if (pinMemory) return nullptr; } // The non_blocking arg should be constant `False`. bool nonBlocking; if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking))) return nullptr; else if (nonBlocking) return nullptr; // The copy arg should be constant `False`. bool copyArg; if (!matchPattern(getCopy(), m_TorchConstantBool(©Arg))) return nullptr; else if (copyArg) return nullptr; // The device arg must be `none`. if (!getDevice().getType().isa()) return nullptr; // The memory_format arg must be `none`. if (!getMemoryFormat().getType().isa()) return nullptr; auto inputType = getSelf().getType().cast(); auto resType = getType().cast(); // If the types aren't equal, then we can't fold. if (inputType != resType) return nullptr; // If the type does not have a statically known dtype, then we cannot fold. 
// For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong, // since the `unk` could be dynamically different for the operand and result. if (!inputType.hasDtype()) return nullptr; // The layout arg should be either `none` or `0` i.e. strided. if (!getLayout().getType().isa()) { int64_t tensorLayout; if (!matchPattern(getLayout(), m_TorchConstantInt(&tensorLayout))) return nullptr; else if (tensorLayout != torch_upstream::Layout::Strided) return nullptr; } // Fold when both the input tensor and result are of the same type and the // layout arg is strided. return getOperand(0); } void AtenToDtypeLayoutOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { // `to.dtype_layout` -> `to.device/to.dtype` if layout is none and pin memory // is false patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) { // The pin_memory arg should be either constant `False` or `none`. if (!op.getPinMemory().getType().isa()) { bool pinMemory; if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory))) return failure(); else if (pinMemory) return failure(); } // The layout arg should be either `none` or `0` i.e. strided. if (!op.getLayout().getType().isa()) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return failure(); else if (tensorLayout != torch_upstream::Layout::Strided) return failure(); } if (op.getDevice().getType().isa()) { // The device arg is `none`. Rewrite to to.dtype. AtenToDtypeOp toDtype = rewriter.create( op.getLoc(), op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); rewriter.replaceOp(op, toDtype->getResults()); } else { // The device arg is not `none`. Rewrite to to.device. 
AtenToDeviceOp toDevice = rewriter.create( op.getLoc(), op.getType(), op.getSelf(), op.getDevice(), op.getDtype(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); rewriter.replaceOp(op, toDevice->getResults()); } return success(); }); } //===----------------------------------------------------------------------===// // AtenToOtherOp //===----------------------------------------------------------------------===// void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // Canonicalize `aten.to.other` to `aten.to.device` patterns.add(+[](AtenToOtherOp op, PatternRewriter &rewriter) { auto lhs = op.getSelf(); auto rhs = op.getOther(); auto getRhsDevice = rewriter.create(op.getLoc(), rhs); auto getRhsDtype = rewriter.create(op.getLoc(), rhs); rewriter.replaceOpWithNewOp( op, op.getType(), lhs, getRhsDevice.getResult(), getRhsDtype.getResult(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); return success(); }); } //===----------------------------------------------------------------------===// // AtenViewOp //===----------------------------------------------------------------------===// OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { auto inputType = getOperand(0).getType().dyn_cast(); if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1) return nullptr; auto resType = getType().dyn_cast(); if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1) return nullptr; if (inputType != resType) return nullptr; // Fold when both the input tensor and result are unity rank tensors. return getOperand(0); } //===----------------------------------------------------------------------===// // PrimsViewOfOp //===----------------------------------------------------------------------===// OpFoldResult PrimsViewOfOp::fold(FoldAdaptor adaptor) { // Always fold the op with its only input operand. 
return getOperand(); } //===----------------------------------------------------------------------===// // AtenDimOp //===----------------------------------------------------------------------===// OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) { if (auto tensorType = getOperand().getType().dyn_cast()) { if (tensorType.hasSizes()) return IntegerAttr::get(IntegerType::get(getContext(), 64), tensorType.getSizes().size()); } return nullptr; } //===----------------------------------------------------------------------===// // AtenLenTOp //===----------------------------------------------------------------------===// OpFoldResult AtenLenTOp::fold(FoldAdaptor adaptor) { // `len([1,1,1])` -> `3`, if it is not mutated. if (auto listConstruct = getOperand().getDefiningOp()) { if (!isListPotentiallyMutated(listConstruct)) { return IntegerAttr::get(IntegerType::get(getContext(), 64), listConstruct.getNumOperands()); } } return nullptr; } void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // `len(t.size())` -> `t.ndim` patterns.add(+[](AtenLenTOp op, PatternRewriter &rewriter) { auto size = op.getOperand().getDefiningOp(); if (!size) return rewriter.notifyMatchFailure(op, "operand not AtenSizeOp"); rewriter.replaceOpWithNewOp(op, size.getOperand()); return success(); }); } //===----------------------------------------------------------------------===// // AtenMinOtherOp //===----------------------------------------------------------------------===// void AtenMinOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // `aten.min.other` -> `aten.minimum` patterns.add(+[](AtenMinOtherOp op, PatternRewriter &rewriter) { rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getOther()); return success(); }); } //===----------------------------------------------------------------------===// // AtenMaxOtherOp //===----------------------------------------------------------------------===// 
void AtenMaxOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // `aten.max.other` -> `aten.maximum` patterns.add(+[](AtenMaxOtherOp op, PatternRewriter &rewriter) { rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getOther()); return success(); }); } //===----------------------------------------------------------------------===// // AtenLenStrOp //===----------------------------------------------------------------------===// OpFoldResult AtenLenStrOp::fold(FoldAdaptor adaptor) { if (auto stringConstruct = getS().getDefiningOp()) return getI64IntegerAttr(getContext(), stringConstruct.getValueAttr().getValue().size()); return nullptr; } LogicalResult rewrite0DBinaryTensorOp(Operation *op, PatternRewriter &rewriter) { Location loc = op->getLoc(); // This canonicalization pattern also includes aten div/mul/add/sub ops // between tensor and scalar, like aten.add.Scalar op if (op->getNumOperands() < 2) { return failure(); } auto lhs = getScalarIntValue(op->getOperand(0), loc, rewriter); auto rhs = getScalarIntValue(op->getOperand(1), loc, rewriter); auto outType = op->getResult(0).getType(); if (!lhs || !rhs) { return rewriter.notifyMatchFailure( op, "only int scalar lhs or rhs is supported"); } if (isa(op)) { Value alpha = getScalarIntValue(op->getOperand(2), loc, rewriter); if (!alpha) { return rewriter.notifyMatchFailure(op, "only int scalar alpha is supported"); } if (isa(op)) lhs = rewriter.create(loc, lhs, alpha); else rhs = rewriter.create(loc, rhs, alpha); } if (isa(op)) { // None rounding mode if (op->getOperand(2).getType().isa()) { Value quotient = rewriter.create(loc, lhs, rhs); rewriter.replaceOpWithNewOp(op, outType, quotient); return success(); } std::string roundingMode; if (!matchPattern(op->getOperand(2), m_TorchConstantStr(roundingMode))) { return rewriter.notifyMatchFailure( op, "only None, 'floor' or 'trunc' rounding mode is supported"); } if (roundingMode == "floor") { Value quotient = 
rewriter.create(loc, lhs, rhs); rewriter.replaceOpWithNewOp(op, outType, quotient); return success(); } // For "trunc" rounding mode, insted of canonicalizing it into // aten.abs, aten.floor, aten.sign and aten.mul.int ops, which adds // complexity but helps little in optimization (such as constant folding), // we are trying to fold it. if (roundingMode == "trunc") { int64_t lhsInt; int64_t rhsInt; if (!matchPattern(lhs, m_TorchConstantInt(&lhsInt))) { return failure(); } if (!matchPattern(rhs, m_TorchConstantInt(&rhsInt))) { return failure(); } int64_t result = (int64_t)std::trunc((double)lhsInt / rhsInt); Value resultScalar = rewriter.create( loc, rewriter.getI64IntegerAttr(result)); rewriter.replaceOpWithNewOp(op, outType, resultScalar); return success(); } return failure(); } Value result; // Other Add/Sub/Mul ops if (isa(op)) { result = rewriter.create(loc, lhs, rhs); } else if (isa(op)) { result = rewriter.create(loc, lhs, rhs); } else if (isa(op)) { result = rewriter.create(loc, rhs, lhs); } else if (isa(op)) { result = rewriter.create(loc, lhs, rhs); } rewriter.replaceOpWithNewOp(op, outType, result); return success(); } //===----------------------------------------------------------------------===// // NAry folder helpers //===----------------------------------------------------------------------===// static bool checkSameDTypes(llvm::ArrayRef attrs) { bool allFp = true; bool allInt = true; for (auto attr : attrs) { if (!attr) return false; Type attrty; if (auto dense = dyn_cast_or_null(attr)) attrty = dense.getType(); if (auto fp = dyn_cast_or_null(attr)) attrty = fp.getType(); if (auto integer = dyn_cast_or_null(attr)) attrty = integer.getType(); if (auto shaped = dyn_cast_or_null(attrty)) attrty = shaped.getElementType(); allFp &= isa(attrty); allInt &= isa(attrty); } return allFp || allInt; } static bool checkAllSplats(llvm::ArrayRef attrs) { for (auto attr : attrs) { if (auto dense = dyn_cast_or_null(attr)) { if (!dense.isSplat()) return false; } } 
return true; } llvm::SmallVector getFoldValueAtIndexFp(llvm::ArrayRef attrs, int64_t idx = 0) { llvm::SmallVector splattrs; for (auto attr : attrs) { if (auto dense = dyn_cast(attr)) { if (dense.isSplat()) { splattrs.push_back(dense.getSplatValue().convertToDouble()); } else { splattrs.push_back(dense.getValues()[idx].convertToDouble()); } } else if (auto intattr = dyn_cast(attr)) { splattrs.push_back(intattr.getValueAsDouble()); } else { return {}; } } return splattrs; } llvm::SmallVector getFoldValueAtIndexInt(llvm::ArrayRef attrs, int64_t bitwidth, int64_t idx = 0) { llvm::SmallVector splattrs; for (auto attr : attrs) { bool isunsigned = false; if (auto dense = dyn_cast(attr)) { isunsigned = dyn_cast(dense.getElementType()).isUnsigned(); if (dense.isSplat()) { splattrs.push_back(dense.getSplatValue()); } else { splattrs.push_back(dense.getValues()[idx]); } } else if (auto intattr = dyn_cast(attr)) { isunsigned = cast(intattr.getType()).isUnsigned(); splattrs.push_back(intattr.getValue()); } else { return {}; } auto &apint = splattrs.back(); if (apint.getBitWidth() < bitwidth) { if (isunsigned) { apint = apint.zextOrTrunc(bitwidth); } else { apint = apint.sextOrTrunc(bitwidth); } } } return splattrs; } using NAryFoldFpOperator = std::function)>; using NAryFoldIntOperator = std::function)>; static OpFoldResult naryFolderHelper(ArrayRef operands, Type ty, NAryFoldFpOperator fpFolder, NAryFoldIntOperator intFolder) { constexpr int64_t maxFold = 16; if (!checkSameDTypes(operands)) return nullptr; auto resultTy = dyn_cast(ty); if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes()) return nullptr; auto dty = resultTy.getDtype(); auto resultBTy = resultTy.toBuiltinTensor().clone(dty); auto fpTy = dyn_cast(dty); auto intTy = dyn_cast(dty); if (!fpTy && !intTy) return nullptr; bool allSplats = checkAllSplats(operands); bool withinMaxFold = resultBTy.hasStaticShape() && resultBTy.getNumElements() <= maxFold; if (!allSplats && !withinMaxFold) return nullptr; // We 
do not support broadcasting in the non-splat case so validate same // shaped inputs / outputs: if (!allSplats) { auto resultShape = resultBTy.getShape(); for (int i = 0, s = operands.size(); i < s; ++i) { if (auto dense = dyn_cast(operands[i])) { if (dense.isSplat()) continue; auto operandShape = cast(dense.getType()).getShape(); if (operandShape.size() != resultShape.size()) return nullptr; for (int i = 0, s = operandShape.size(); i < s; ++i) if (operandShape[i] != resultShape[i]) return nullptr; } } } const int64_t numValues = allSplats ? 1 : resultBTy.getNumElements(); if (fpTy) { llvm::SmallVector folded; for (int i = 0, s = numValues; i < s; ++i) { auto inputs = getFoldValueAtIndexFp(operands, i); double fold = fpFolder(inputs); APFloat val(fold); bool unused; val.convert(fpTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, &unused); folded.push_back(val); } return DenseElementsAttr::get(resultBTy, folded); } if (intTy) { llvm::SmallVector folded; for (int i = 0, s = numValues; i < s; ++i) { auto inputs = getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i); folded.push_back(intFolder(inputs)); } return DenseElementsAttr::get(resultBTy, folded); } return nullptr; } //===----------------------------------------------------------------------===// // AtenAddTensorOp //===----------------------------------------------------------------------===// void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenAddTensorOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } OpFoldResult AtenAddTensorOp::fold(FoldAdaptor adaptor) { auto fpFold = [](llvm::ArrayRef inputs) { assert(inputs.size() == 3); return inputs[0] + (inputs[1] * inputs[2]); }; auto intFold = [](llvm::ArrayRef inputs) { assert(inputs.size() == 3); return inputs[0] + (inputs[1] * inputs[2]); }; return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); } 
//===----------------------------------------------------------------------===// // AtenAddScalarOp //===----------------------------------------------------------------------===// void AtenAddScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenAddScalarOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } //===----------------------------------------------------------------------===// // AtenSubTensorOp //===----------------------------------------------------------------------===// void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenSubTensorOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } OpFoldResult AtenSubTensorOp::fold(FoldAdaptor adaptor) { auto fpFold = [](llvm::ArrayRef inputs) { assert(inputs.size() == 3); return inputs[0] - (inputs[1] * inputs[2]); }; auto intFold = [](llvm::ArrayRef inputs) { assert(inputs.size() == 3); return inputs[0] - (inputs[1] * inputs[2]); }; return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); } //===----------------------------------------------------------------------===// // AtenSubScalarOp //===----------------------------------------------------------------------===// void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenSubScalarOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } //===----------------------------------------------------------------------===// // AtenRSubScalarOp //===----------------------------------------------------------------------===// void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenRsubScalarOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } 
//===----------------------------------------------------------------------===// // AtenMulTensorOp //===----------------------------------------------------------------------===// void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenMulTensorOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } OpFoldResult AtenMulTensorOp::fold(FoldAdaptor adaptor) { auto fpFold = [](llvm::ArrayRef inputs) { assert(inputs.size() == 2); return inputs[0] * inputs[1]; }; auto intFold = [](llvm::ArrayRef inputs) { assert(inputs.size() == 2); return inputs[0] * inputs[1]; }; return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); } //===----------------------------------------------------------------------===// // AtenEqTensorOp //===----------------------------------------------------------------------===// OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) { constexpr int64_t kMaxFold = 16; auto ty = dyn_cast(getType()); if (!ty || !ty.hasDtype() || !ty.hasSizes()) return nullptr; auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); if (!bty.hasStaticShape()) return nullptr; if (getSelf() == getOther()) return DenseElementsAttr::get(bty, IntegerAttr::get(bty.getElementType(), 1)); auto self = dyn_cast_or_null(adaptor.getSelf()); auto other = dyn_cast_or_null(adaptor.getOther()); if (!self || !other) return nullptr; auto selfTy = dyn_cast(self.getType()); auto otherTy = dyn_cast(other.getType()); if (!selfTy || !otherTy || selfTy.getElementType() != otherTy.getElementType()) return nullptr; // If both values are splats we can just compute the output value as a splat. 
if (self.isSplat() && other.isSplat()) { if (isa(selfTy.getElementType())) { APFloat lhsFp = self.getSplatValue(); APFloat rhsFp = other.getSplatValue(); bool eq = lhsFp.compare(rhsFp) == APFloat::cmpEqual; return DenseElementsAttr::get(bty, eq); } if (isa(selfTy.getElementType())) { APInt lhsInt = self.getSplatValue(); APInt rhsInt = other.getSplatValue(); bool eq = lhsInt == rhsInt; return DenseElementsAttr::get(bty, eq); } return nullptr; } if (selfTy != otherTy || bty.getNumElements() > kMaxFold) return nullptr; if (isa(selfTy.getElementType())) { auto extract = [bty](DenseElementsAttr attr) { llvm::SmallVector vals; if (attr.isSplat()) { vals.resize(bty.getNumElements(), attr.getSplatValue()); return vals; } for (auto fp : attr.getValues()) { vals.push_back(fp); } return vals; }; llvm::SmallVector lhsFp = extract(self); llvm::SmallVector rhsFp = extract(other); llvm::SmallVector vals(bty.getNumElements()); for (int i = 0, s = bty.getNumElements(); i < s; ++i) { vals[i] = lhsFp[i].compare(rhsFp[i]) == APFloat::cmpEqual; } return DenseElementsAttr::get(bty, vals); } if (isa(selfTy.getElementType())) { auto extract = [bty](DenseElementsAttr attr) { llvm::SmallVector vals; if (attr.isSplat()) { vals.resize(bty.getNumElements(), attr.getSplatValue()); return vals; } for (auto fp : attr.getValues()) { vals.push_back(fp); } return vals; }; llvm::SmallVector lhsInt = extract(self); llvm::SmallVector rhsInt = extract(other); llvm::SmallVector vals(bty.getNumElements()); for (int i = 0, s = bty.getNumElements(); i < s; ++i) { vals[i] = lhsInt[i] == rhsInt[i]; } return DenseElementsAttr::get(bty, vals); } return nullptr; } //===----------------------------------------------------------------------===// // AtenLeScalarOp //===----------------------------------------------------------------------===// using ComparisonFoldFpOperator = std::function; using ComparisonFoldIntOperator = std::function; static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute 
rhs, ValueTensorType resultTy, ComparisonFoldFpOperator fpFolder, ComparisonFoldIntOperator intFolder) { constexpr int64_t kMaxFold = 16; if (!lhs || !rhs || !resultTy) return nullptr; if (!resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; for (auto size : resultTy.getSizes()) if (size == Torch::kUnknownSize) return nullptr; auto ctx = lhs.getContext(); auto resultETy = resultTy.getDtype(); auto tensorETy = cast(lhs.getType()).getElementType(); if (lhs.isSplat()) { if (auto intAttr = dyn_cast(rhs)) { auto unsign = cast(tensorETy).isUnsigned(); auto scalarAP = intAttr.getValue(); auto tensorAP = lhs.getSplatValue().getValue(); tensorAP = APInt( scalarAP.getBitWidth(), unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); auto resultBool = intFolder(tensorAP, scalarAP, unsign); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), resultAP); } if (auto floatAttr = dyn_cast(rhs)) { APFloat scalarAP = floatAttr.getValue(); APFloat tensorAP = lhs.getSplatValue().getValue(); auto resultBool = fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), resultAP); } return nullptr; } int64_t count = 1; for (auto size : resultTy.getSizes()) count *= size; if (count > kMaxFold) return nullptr; if (auto intAttr = dyn_cast(rhs)) { auto unsign = cast(tensorETy).isUnsigned(); llvm::SmallVector values; for (auto tensorAP : lhs.getValues()) { auto scalarAP = intAttr.getValue(); tensorAP = APInt( scalarAP.getBitWidth(), unsign ? 
tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); auto resultBool = intFolder(tensorAP, scalarAP, unsign); values.push_back(resultBool); } return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), values); } if (auto floatAttr = dyn_cast(rhs)) { llvm::SmallVector values; for (auto tensorAP : lhs.getValues()) { APFloat scalarAP = floatAttr.getValue(); auto resultBool = fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); values.push_back(resultBool); } return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), values); } return nullptr; } OpFoldResult AtenLeScalarOp::fold(FoldAdaptor adaptor) { auto self = dyn_cast_or_null(adaptor.getSelf()); auto other = adaptor.getOther(); auto resultTy = dyn_cast(getType()); auto fpFold = [](double lhs, double rhs) -> bool { return lhs <= rhs; }; auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { return unsign ? lhs.ule(rhs) : lhs.sle(rhs); }; return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); } //===----------------------------------------------------------------------===// // AtenLtScalarOp //===----------------------------------------------------------------------===// OpFoldResult AtenLtScalarOp::fold(FoldAdaptor adaptor) { auto self = dyn_cast_or_null(adaptor.getSelf()); auto other = adaptor.getOther(); auto resultTy = dyn_cast(getType()); auto fpFold = [](double lhs, double rhs) -> bool { return lhs < rhs; }; auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { return unsign ? 
lhs.ult(rhs) : lhs.slt(rhs); }; return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); } //===----------------------------------------------------------------------===// // AtenGtScalarOp //===----------------------------------------------------------------------===// OpFoldResult AtenGtScalarOp::fold(FoldAdaptor adaptor) { auto self = dyn_cast_or_null(adaptor.getSelf()); auto other = adaptor.getOther(); auto resultTy = dyn_cast(getType()); auto fpFold = [](double lhs, double rhs) -> bool { return lhs > rhs; }; auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { return unsign ? lhs.ugt(rhs) : lhs.sgt(rhs); }; return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); } //===----------------------------------------------------------------------===// // AtenGeScalarOp //===----------------------------------------------------------------------===// OpFoldResult AtenGeScalarOp::fold(FoldAdaptor adaptor) { auto self = dyn_cast_or_null(adaptor.getSelf()); auto other = adaptor.getOther(); auto resultTy = dyn_cast(getType()); auto fpFold = [](double lhs, double rhs) -> bool { return lhs >= rhs; }; auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { return unsign ? 
lhs.uge(rhs) : lhs.sge(rhs); }; return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); } //===----------------------------------------------------------------------===// // AtenEqScalarOp //===----------------------------------------------------------------------===// OpFoldResult AtenEqScalarOp::fold(FoldAdaptor adaptor) { auto self = dyn_cast_or_null(adaptor.getSelf()); auto other = adaptor.getOther(); auto resultTy = dyn_cast(getType()); auto fpFold = [](double lhs, double rhs) -> bool { return lhs == rhs; }; auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { return lhs.eq(rhs); }; return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); } //===----------------------------------------------------------------------===// // AtenNeScalarOp //===----------------------------------------------------------------------===// OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) { auto self = dyn_cast_or_null(adaptor.getSelf()); auto other = adaptor.getOther(); auto resultTy = dyn_cast(getType()); auto fpFold = [](double lhs, double rhs) -> bool { return lhs != rhs; }; auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { return lhs.ne(rhs); }; return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); } //===----------------------------------------------------------------------===// // AtenFloorOp //===----------------------------------------------------------------------===// OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) { auto resultType = getType().dyn_cast(); if (resultType && resultType.hasDtype() && resultType.getDtype().isa()) { return getSelf(); } return {}; } //===----------------------------------------------------------------------===// // AtenCeilOp //===----------------------------------------------------------------------===// OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) { auto resultType = getType().dyn_cast(); if (resultType && resultType.hasDtype() && resultType.getDtype().isa()) { 
return getSelf(); } return {}; } //===----------------------------------------------------------------------===// // AtenRoundOp //===----------------------------------------------------------------------===// OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { auto resultType = getType().dyn_cast(); if (resultType && resultType.hasDtype() && resultType.getDtype().isa()) { return getSelf(); } return {}; } //===----------------------------------------------------------------------===// // AtenMulScalarOp //===----------------------------------------------------------------------===// void AtenMulScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenMulScalarOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } //===----------------------------------------------------------------------===// // AtenDivTensorModeOp //===----------------------------------------------------------------------===// void AtenDivTensorModeOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenDivTensorModeOp op, PatternRewriter &rewriter) { return rewrite0DBinaryTensorOp(op, rewriter); }); } //===----------------------------------------------------------------------===// // AtenNumelOp //===----------------------------------------------------------------------===// void AtenNumelOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenNumelOp op, PatternRewriter &rewriter) { auto inputType = op.getSelf().getType().dyn_cast(); if (!inputType || !inputType.areAllSizesKnown()) { return failure(); } auto sizes = inputType.getSizes(); int64_t numel = 1; for (int64_t d : sizes) { numel *= d; } rewriter.replaceOpWithNewOp( op, rewriter.getI64IntegerAttr(numel)); return success(); }); } //===----------------------------------------------------------------------===// // Aten__Or__TensorOp 
//===----------------------------------------------------------------------===// void Aten__Or__TensorOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](Aten__Or__TensorOp op, PatternRewriter &rewriter) { rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getOther()); return success(); }); } //===----------------------------------------------------------------------===// // AtenScalarImplicitOp //===----------------------------------------------------------------------===// void AtenScalarImplicitOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenScalarImplicitOp op, PatternRewriter &rewriter) { Location loc = op.getLoc(); Value a = op.getA(); auto outType = op.getResult().getType(); Value scalarValue = getScalarIntValue(a, loc, rewriter); if (!scalarValue) return failure(); rewriter.replaceOpWithNewOp(op, outType, scalarValue); return success(); }); } //===----------------------------------------------------------------------===// // AtenFloatImplicitOp //===----------------------------------------------------------------------===// void AtenFloatImplicitOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenFloatImplicitOp op, PatternRewriter &rewriter) { Location loc = op.getLoc(); Value a = op.getA(); Value scalarValue = getScalarFloatValue(a, loc, rewriter); if (!scalarValue) return failure(); rewriter.replaceOp(op, scalarValue); return success(); }); } //===----------------------------------------------------------------------===// // AtenIntImplicitOp //===----------------------------------------------------------------------===// void AtenIntImplicitOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenIntImplicitOp op, PatternRewriter &rewriter) { Location loc = op.getLoc(); Value a = op.getA(); Value scalarValue = getScalarIntValue(a, 
loc, rewriter); if (!scalarValue) return failure(); rewriter.replaceOp(op, scalarValue); return success(); }); } //===----------------------------------------------------------------------===// // AtenSizeOp //===----------------------------------------------------------------------===// // Traces at most 6 parents of `value` to determine the tensor type with known // dimension size or returns failure if such a type was not found. If `dim` is // `None`, then all dimension's sizes must be known. static FailureOr traceKnownSizeTensorType(Value value, std::optional dim) { // Function to check if we found a type that contains the queried information. auto foundType = [](BaseTensorType tensorType, std::optional(dim)) { if (!tensorType.hasSizes()) return false; if (dim == std::nullopt) return tensorType.areAllSizesKnown(); // If the dimension value is negative, then convert it to a positive value. ArrayRef sizes = tensorType.getSizes(); *dim = toPositiveDim(*dim, sizes.size()); return isValidDim(*dim, sizes.size()) && sizes[*dim] != kUnknownSize; }; // Limit the loop count to 6 to avoid indefinite compilation times from // unbounded IR traversals. for (auto idx = 0; idx < 6; ++idx) { if (!value || !value.getType().isa()) return failure(); auto tensorType = value.getType().cast(); if (foundType(tensorType, dim)) return tensorType; auto op = value.getDefiningOp(); if (!op || !isa(op)) return failure(); // In all ops of interest to us, the source tensor is operand #0. 
value = op->getOperand(0); } return failure(); } void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) { auto type = traceKnownSizeTensorType(op.getOperand(), std::nullopt); if (failed(type)) return rewriter.notifyMatchFailure(op, "all sizes not known"); SmallVector listElements; for (int64_t size : type->getSizes()) { listElements.push_back(rewriter.create( op->getLoc(), rewriter.getI64IntegerAttr(size))); } rewriter.replaceOpWithNewOp( op, Torch::ListType::get(rewriter.getType()), listElements); return success(); }); // One-off pattern to erase if dead. // TODO: Use the effects infra to express the semantics of this op and enable // a centralized "erase if dead" canonicalization. // Specifically, we need to mark the op as only MemoryEffects::Allocate // so that `mlir::wouldOpBeTriviallyDead` does the right thing. patterns.add(+[](AtenSizeOp op, PatternRewriter &rewriter) { if (!op.use_empty()) return failure(); rewriter.eraseOp(op); return failure(); }); } //===----------------------------------------------------------------------===// // AtenSelectIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) { auto self = dyn_cast_or_null(adaptor.getSelf()); auto ty = dyn_cast(getType()); if (!self || !ty || !ty.hasDtype() || !ty.hasSizes()) return nullptr; auto selfTy = cast(self.getType()); auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); if (!bty.hasStaticShape()) return nullptr; if (self.isSplat()) return DenseElementsAttr::get(bty, self.getSplatValue()); auto dimAttr = dyn_cast_or_null(adaptor.getDim()); auto indexAttr = dyn_cast_or_null(adaptor.getIndex()); if (!dimAttr || !indexAttr || bty.getNumElements() != 1) return nullptr; auto dim = dimAttr.getInt(); auto index = indexAttr.getInt(); for (int i = 0, s = selfTy.getRank(); i < s; ++i) { if (i != dim && 
selfTy.getDimSize(i) != 1) return nullptr; } auto splattr = self.getValues()[index]; return DenseElementsAttr::get(bty, splattr); } //===----------------------------------------------------------------------===// // AtenSizeIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) { int64_t dim; if (!matchPattern(this->getDim(), m_TorchConstantInt(&dim))) return nullptr; auto type = traceKnownSizeTensorType(this->getSelf(), dim); if (failed(type)) return nullptr; ArrayRef sizes = type->getSizes(); dim = toPositiveDim(dim, sizes.size()); if (!isValidDim(dim, sizes.size())) return nullptr; return IntegerAttr::get(IntegerType::get(getContext(), 64), sizes[dim]); } //===----------------------------------------------------------------------===// // AtenGtIntOp //===----------------------------------------------------------------------===// static IntegerAttr getI1IntegerAttr(MLIRContext *context, bool value) { return IntegerAttr::get(IntegerType::get(context, 1), static_cast(value)); } using ConstantFloatComparator = std::function; template static OpFoldResult floatComparatorFoldHelper(OpTy op, ConstantFloatComparator comparator) { if (op.getOperand(0) == op.getOperand(1)) return getI1IntegerAttr(op.getContext(), comparator(0, 0)); double lhs, rhs; if (!matchPattern(op.getOperand(0), m_TorchConstantFloat(&lhs)) || !matchPattern(op.getOperand(1), m_TorchConstantFloat(&rhs))) return nullptr; return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs)); } //===----------------------------------------------------------------------===// // AtenLtFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenLtFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a < b; }); } //===----------------------------------------------------------------------===// // AtenGtFloatOp 
//===----------------------------------------------------------------------===// OpFoldResult AtenGtFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a > b; }); } //===----------------------------------------------------------------------===// // AtenGeFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenGeFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a >= b; }); } //===----------------------------------------------------------------------===// // AtenEqFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenEqFloatOp::fold(FoldAdaptor adaptor) { return floatComparatorFoldHelper(*this, [](double a, double b) { return a == b; }); } using ConstantIntComparator = std::function; template static OpFoldResult intComparatorFoldHelper(OpTy op, ConstantIntComparator comparator) { Value lhsValue = op->getOperand(0); Value rhsValue = op->getOperand(1); if (lhsValue == rhsValue) return getI1IntegerAttr(op.getContext(), comparator(0, 0)); int64_t lhs, rhs; bool lhsIsConstant = matchPattern(lhsValue, m_TorchConstantInt(&lhs)); bool rhsIsConstant = matchPattern(rhsValue, m_TorchConstantInt(&rhs)); if (lhsIsConstant && rhsIsConstant) return getI1IntegerAttr(op.getContext(), comparator(lhs, rhs)); // Ensure that if there is a constant, it is on the right. if (lhsIsConstant && !rhsIsConstant) { std::swap(lhs, rhs); std::swap(lhsValue, rhsValue); std::swap(lhsIsConstant, rhsIsConstant); auto newComparator = [comparator](int64_t lhs, int64_t rhs) { return comparator(rhs, lhs); }; comparator = newComparator; } // Fold comparisons of AtenSizeIntOp against negative values. // AtenSizeIntOp is known to always be non-negative. 
if (rhsIsConstant && rhs < 0) { // We can return `comparator(0, -1)` here because of the property: // If x >= 0 && y < 0, then: // - cmp(x, y) == cmp(x + 1, y) // - cmp(x, y) == cmp(x, y - 1) // By induction all cases here are covered. if (auto size = lhsValue.getDefiningOp()) return getI1IntegerAttr(op->getContext(), comparator(0, -1)); } // Fold comparisons of AtenSizeIntOp against 0: // - torch.aten.size.int >= 0 ==> True. // - torch.aten.size.int < 0 ==> False. // (and the operand-swapped versions of the above) if (rhsIsConstant && rhs == 0) { if (auto size = lhsValue.getDefiningOp()) { // >= 0 comparison. if (comparator(0, 0) && comparator(1, 0)) return getI1IntegerAttr(op->getContext(), true); // < 0 comparison. if (!comparator(0, 0) && comparator(-1, 0) && !comparator(1, 0)) return getI1IntegerAttr(op->getContext(), false); } } return nullptr; } //===----------------------------------------------------------------------===// // AtenDetachOp //===----------------------------------------------------------------------===// OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { if (getSelf().getType() != getResult().getType()) return {}; return getSelf(); } //===----------------------------------------------------------------------===// // AtenNeIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenNeIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a != b; }); } //===----------------------------------------------------------------------===// // AtenEqIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenEqIntOp::fold(FoldAdaptor adaptor) { return intComparatorFoldHelper(*this, [](int64_t a, int64_t b) { return a == b; }); } //===----------------------------------------------------------------------===// // AtenEqStrOp //===----------------------------------------------------------------------===// 
// Folds `aten.eq.str` to an i1 constant. The result is true when both
// operands are the same SSA value. When both operands come from
// `torch.constant.str`, the result is the equality of their string contents.
OpFoldResult AtenEqStrOp::fold(FoldAdaptor adaptor) {
  if (getOperand(0) == getOperand(1))
    return getI1IntegerAttr(getContext(), true);

  auto aStr = getA().getDefiningOp<ConstantStrOp>();
  auto bStr = getB().getDefiningOp<ConstantStrOp>();

  // Compare the string values, not the op handles. Two distinct constant ops
  // can hold the same string, and comparing the ops would fold equal strings
  // to false.
  if (aStr && bStr)
    return getI1IntegerAttr(getContext(), aStr.getValueAttr().getValue() ==
                                              bStr.getValueAttr().getValue());
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenLtIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLtIntOp::fold(FoldAdaptor adaptor) {
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a < b; });
}

//===----------------------------------------------------------------------===//
// AtenLeIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLeIntOp::fold(FoldAdaptor adaptor) {
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a <= b; });
}

//===----------------------------------------------------------------------===//
// AtenGtIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenGtIntOp::fold(FoldAdaptor adaptor) {
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a > b; });
}

//===----------------------------------------------------------------------===//
// AtenGeIntOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenGeIntOp::fold(FoldAdaptor adaptor) {
  return intComparatorFoldHelper(*this,
                                 [](int64_t a, int64_t b) { return a >= b; });
}

//===----------------------------------------------------------------------===//
// AtenBoolFloatOp
//===----------------------------------------------------------------------===//

// Folds `bool(float)` of a constant to `c != 0.0`.
OpFoldResult AtenBoolFloatOp::fold(FoldAdaptor adaptor) {
  double c;
  if (matchPattern(getOperand(), m_TorchConstantFloat(&c)))
    return getI1IntegerAttr(getContext(), c != 0.0);
  return nullptr;
}
//===----------------------------------------------------------------------===// // AtenBoolIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenBoolIntOp::fold(FoldAdaptor adaptor) { int64_t c; if (matchPattern(getOperand(), m_TorchConstantInt(&c))) return getI1IntegerAttr(getContext(), c != 0); return nullptr; } //===----------------------------------------------------------------------===// // AtenAnyBoolOp //===----------------------------------------------------------------------===// OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) { auto inputConstruct = getSelf().getDefiningOp(); if (!inputConstruct || isListPotentiallyMutated(inputConstruct)) return nullptr; // If any operand is a constant true, return true. for (auto operand : inputConstruct.getOperands()) { bool b = false; if (matchPattern(operand, m_TorchConstantBool(&b)) && b) { return getI1IntegerAttr(getContext(), true); } } return nullptr; } //===----------------------------------------------------------------------===// // AtenFloatScalarOp //===----------------------------------------------------------------------===// OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) { // Constant fold int -> float conversion. if (auto integerAttr = adaptor.getA().dyn_cast_or_null()) { return FloatAttr::get( mlir::Float64Type::get(getContext()), static_cast(integerAttr.getValue().getSExtValue())); } // If the input is float type already, the op is an identity. if (getType() == getOperand().getType()) return getOperand(); return nullptr; } //===----------------------------------------------------------------------===// // AtenIntFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) { // Constant fold float -> int conversion. 
if (auto floatAttr = adaptor.getA().dyn_cast_or_null()) { return IntegerAttr::get( mlir::IntegerType::get(getContext(), 64), static_cast(floatAttr.getValue().convertToDouble())); } return nullptr; } //===----------------------------------------------------------------------===// // AtenIntScalarOp //===----------------------------------------------------------------------===// OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) { // Constant fold float -> int conversion. if (auto floatAttr = adaptor.getA().dyn_cast_or_null()) { return IntegerAttr::get( mlir::IntegerType::get(getContext(), 64), static_cast(floatAttr.getValue().convertToDouble())); } // If the input is int type already, the op is an identity. if (getType() == getOperand().getType()) return getOperand(); return nullptr; } //===----------------------------------------------------------------------===// // AtenIntBoolOp //===----------------------------------------------------------------------===// OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) { bool b; if (matchPattern(getOperand(), m_TorchConstantBool(&b))) { return getI64IntegerAttr(getContext(), static_cast(b)); } return nullptr; } //===----------------------------------------------------------------------===// // AtenMaskedFillTensorOp //===----------------------------------------------------------------------===// // Fold 0d fill tensor to scalar void AtenMaskedFillTensorOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenMaskedFillTensorOp op, PatternRewriter &rewriter) { auto scalarIntVal = getScalarIntValue(op.getValue(), op->getLoc(), rewriter); auto scalarFloatVal = getScalarFloatValue(op.getValue(), op->getLoc(), rewriter); if (!scalarIntVal && !scalarFloatVal) return failure(); Value scalarVal = scalarIntVal ? 
scalarIntVal : scalarFloatVal; rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getMask(), scalarVal); return failure(); }); } //===----------------------------------------------------------------------===// // AtenCloneOp //===----------------------------------------------------------------------===// OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) { // note: memory_format would be ignored if (llvm::dyn_cast(getSelf().getType())) { // self should have value semantics return getSelf(); } return {}; } //===----------------------------------------------------------------------===// // AtenSortIntOp //===----------------------------------------------------------------------===// void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenSortIntOp op, PatternRewriter &rewriter) { SmallVector listElements; if (!matchPattern(op.getSelf(), m_TorchListOfConstantInts(listElements))) return rewriter.notifyMatchFailure( op, "all input list elements must be constant ints"); bool reverse; if (!matchPattern(op.getReverse(), m_TorchConstantBool(&reverse))) return rewriter.notifyMatchFailure( op, "Expected reverse arg to be constant bool."); std::sort(listElements.begin(), listElements.end()); if (reverse) std::reverse(listElements.begin(), listElements.end()); SmallVector sortedListElements; for (int64_t elem : listElements) sortedListElements.push_back(rewriter.create( op->getLoc(), rewriter.getI64IntegerAttr(elem))); Value result = rewriter.create( op->getLoc(), Torch::ListType::get(rewriter.getType()), sortedListElements); op.getSelf().replaceAllUsesWith(result); rewriter.eraseOp(op); return success(); }); } //===----------------------------------------------------------------------===// // AtenSortOp //===----------------------------------------------------------------------===// LogicalResult AtenSortOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { auto operand = getSelf(); auto 
operandType = dyn_cast(operand.getType()); if (!operandType || !operandType.hasSizes()) return failure(); // only ValueTensorType has toBuiltinTensor auto indicesTensorType = dyn_cast(getResult(1).getType()); if (!indicesTensorType) return failure(); if (!indicesTensorType.hasDtype()) return failure(); auto indicesType = indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype()); if (!indicesType || !indicesType.hasStaticShape()) return failure(); bool unaryDim = false; IntegerAttr dimAttribute = dyn_cast_if_present(adaptor.getDim()); if (!dimAttribute) return failure(); int64_t dimInt = dimAttribute.getValue().getSExtValue(); if (dimInt < 0) dimInt += operandType.getSizes().size(); if (dimAttribute) { unaryDim = operandType.getSizes()[dimInt] == 1; } OpBuilder builder(getContext()); if (unaryDim || llvm::all_of(operandType.getSizes(), [](int64_t dim) { return dim == 1; })) { results.push_back(operand); results.push_back(DenseElementsAttr::get( indicesType, builder.getZeroAttr(indicesType.getElementType()))); return success(); } return failure(); } //===----------------------------------------------------------------------===// // NonValueTensorLiteralOp //===----------------------------------------------------------------------===// LogicalResult NonValueTensorLiteralOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { auto attr = properties.as() ->getValue() .dyn_cast_or_null(); if (!attr) return failure(); RankedTensorType tensorType = attr.getType().cast(); NonValueTensorType returnType = NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(), tensorType.getElementType()); inferredReturnTypes.push_back(returnType); return success(); } static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) { if (a.hasSizes() && b.hasSizes()) { if 
(failed(verifyCompatibleShape(makeShapeLLVMCompatible(a.getSizes()), makeShapeLLVMCompatible(b.getSizes())))) return false; } if (a.hasDtype() && b.hasDtype()) { if (a.getDtype() != b.getDtype()) return false; } return true; } bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred, TypeRange actual) { if (!actual[0].isa()) return false; return areSizesAndDtypesCompatible(inferred[0].cast(), actual[0].cast()); } //===----------------------------------------------------------------------===// // ValueTensorLiteralOp //===----------------------------------------------------------------------===// LogicalResult ValueTensorLiteralOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { auto attr = properties.as() ->getValue() .dyn_cast_or_null(); if (!attr) return failure(); RankedTensorType tensorType = attr.getType().cast(); ValueTensorType returnType = ValueTensorType::get(tensorType.getContext(), tensorType.getShape(), tensorType.getElementType()); inferredReturnTypes.push_back(returnType); return success(); } OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } //----------------------------------------------------------------------------// // TensorStaticInfoCast //----------------------------------------------------------------------------// bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs, mlir::TypeRange outputs) { return areSizesAndDtypesCompatible(inputs[0].cast(), outputs[0].cast()); } void TensorStaticInfoCastOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) { auto reverseCast = op.getOperand().getDefiningOp(); if (!reverseCast || reverseCast.getOperand().getType() != op.getType()) return failure(); rewriter.replaceOp(op, 
reverseCast.getOperand()); return success(); }); patterns.add(+[](TensorStaticInfoCastOp op, PatternRewriter &rewriter) { if (isValidSubtype(op.getOperand().getType(), op.getType())) { SmallVector> usesToChange( llvm::make_filter_range(op->getUses(), [](OpOperand &operand) { return operand.getOwner() ->hasTrait(); })); if (usesToChange.empty()) return failure(); for (OpOperand &use : usesToChange) { Operation *user = use.getOwner(); user->setOperand(use.getOperandNumber(), op.getOperand()); } return success(); } return failure(); }); } //===----------------------------------------------------------------------===// // CopyToNonValueTensorOp //===----------------------------------------------------------------------===// LogicalResult CopyToNonValueTensorOp::verify() { auto resultType = getResult().getType().cast(); auto operandType = getOperand().getType().cast(); if (!resultType.hasSameSizesAndDtype(operandType)) return emitError() << "operand and result must have same sizes and dtype"; return success(); } LogicalResult CopyToNonValueTensorOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { auto resultType = operands[0].getType().cast(); inferredReturnTypes.push_back(resultType.getWithoutValueSemantics()); return success(); } void CopyToNonValueTensorOp::getEffects( SmallVectorImpl> &effects) { effects.emplace_back(MemoryEffects::Allocate::get(), getResult()); } //===----------------------------------------------------------------------===// // CopyToValueTensorOp //===----------------------------------------------------------------------===// LogicalResult CopyToValueTensorOp::verify() { auto resultType = getResult().getType().cast(); auto operandType = getOperand().getType().cast(); if (!resultType.hasSameSizesAndDtype(operandType)) return emitError() << "operand and result must have same sizes and dtype"; 
return success(); } LogicalResult CopyToValueTensorOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { auto resultType = operands[0].getType().cast(); inferredReturnTypes.push_back(resultType.getWithValueSemantics()); return success(); } void CopyToValueTensorOp::getEffects( SmallVectorImpl> &effects) { effects.emplace_back(MemoryEffects::Read::get(), getOperand()); } //===----------------------------------------------------------------------===// // ConstantNoneOp //===----------------------------------------------------------------------===// OpFoldResult ConstantNoneOp::fold(FoldAdaptor adaptor) { return TypeAttr::get(Torch::NoneType::get(getContext())); } void ConstantNoneOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), "none"); } //===----------------------------------------------------------------------===// // ConstantStrOp //===----------------------------------------------------------------------===// OpFoldResult ConstantStrOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void ConstantStrOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), "str"); } //===----------------------------------------------------------------------===// // ConstantDeviceOp //===----------------------------------------------------------------------===// void ConstantDeviceOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), getValue()); } //===----------------------------------------------------------------------===// // ConstantIntOp //===----------------------------------------------------------------------===// ParseResult ConstantIntOp::parse(OpAsmParser &parser, OperationState &result) { Builder builder(result.getContext()); result.addTypes(builder.getType()); if (parser.parseOptionalAttrDict(result.attributes)) return failure(); int64_t value; if 
(parser.parseInteger(value)) return failure(); result.addAttribute("value", builder.getI64IntegerAttr(value)); return success(); } void ConstantIntOp::print(OpAsmPrinter &p) { p << " "; p << getValueAttr().getInt(); p.printOptionalAttrDict((*this)->getAttrs(), {"value"}); } OpFoldResult Torch::ConstantIntOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void Torch::ConstantIntOp::getAsmResultNames( function_ref setNameFn) { SmallVector buf; llvm::raw_svector_ostream os(buf); os << "int" << getValueAttr().getInt(); setNameFn(getResult(), os.str()); } //===----------------------------------------------------------------------===// // ConstantFloatOp //===----------------------------------------------------------------------===// OpFoldResult Torch::ConstantFloatOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void Torch::ConstantFloatOp::getAsmResultNames( function_ref setNameFn) { // Calculate a stringified version of the number, compatible with MLIR // identifier syntax. (in practice, this just removes the '+' from 'e+' in // float string representation). SmallVector buf; getValue().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, /*TruncateZero=*/false); auto isValidMLIRIdentifierChar = [](char c) { return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' || c == '-'; }; auto numberStr = llvm::to_vector<16>( llvm::make_filter_range(buf, isValidMLIRIdentifierChar)); // Construct the identifier string. 
buf.clear(); llvm::append_range(buf, StringRef("float")); llvm::append_range(buf, numberStr); setNameFn(getResult(), StringRef(buf.data(), buf.size())); } //===----------------------------------------------------------------------===// // ConstantNumberOp //===----------------------------------------------------------------------===// OpFoldResult Torch::ConstantNumberOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void Torch::ConstantNumberOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](Torch::ConstantNumberOp op, PatternRewriter &rewriter) { Location loc = op->getLoc(); Value constValue; Attribute value = op.getValueAttr(); if (auto floatValue = value.dyn_cast()) { constValue = rewriter.create(loc, floatValue); } else if (auto intValue = value.dyn_cast()) { constValue = rewriter.create(loc, intValue); } else { return failure(); } rewriter.replaceOpWithNewOp(op, op.getType(), constValue); return success(); }); } //===----------------------------------------------------------------------===// // ConstantBoolOp //===----------------------------------------------------------------------===// OpFoldResult Torch::ConstantBoolOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void Torch::ConstantBoolOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), getValue() ? 
"true" : "false"); } //===----------------------------------------------------------------------===// // PrimUncheckedCastOp //===----------------------------------------------------------------------===// bool PrimUncheckedCastOp::areCastCompatible(mlir::TypeRange inputs, mlir::TypeRange outputs) { return isValidSubtype(outputs[0], inputs[0]); } OpFoldResult PrimUncheckedCastOp::fold(FoldAdaptor adaptor) { if (auto derefineOp = getX().getDefiningOp()) { if (derefineOp.getOperand().getType() == getType()) return derefineOp.getOperand(); } return nullptr; } //===----------------------------------------------------------------------===// // Aten__Getitem__TOp //===----------------------------------------------------------------------===// void Aten__Getitem__TOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) { auto torchList = op.getOperand(0); if (isListPotentiallyMutated(torchList)) return failure(); auto listConstruct = torchList.getDefiningOp(); if (!listConstruct) return failure(); // Get the index, but be careful because it might be statically invalid. std::optional indexOpt = matchLegalConstantIndexIntoListOfSize( op.getOperand(1), listConstruct.getNumOperands()); if (!indexOpt) return rewriter.notifyMatchFailure(op, "statically invalid index"); rewriter.replaceOp(op, {listConstruct.getOperand(*indexOpt)}); return success(); }); patterns.add(+[](Aten__Getitem__TOp op, PatternRewriter &rewriter) { auto sizeOp = op.getList().getDefiningOp(); if (!sizeOp) return failure(); // This assumes tht the size doesn't change between the // AtenSizeOp and the Aten__Getitem__TOp. // `t_` is the only op I can find that changes the shape in-place. It seems // like otherwise we can treat the size of a tensor as having value // semantics. The other view-like ops don't have in-place variants -- // they always return a new SSA value that is aliased to the input. 
// Can we have a pass to normalize the `t_` case and then elsewhere in the // compiler treat the size as having value semantics? // There's a small number of such ops, and they are marked as `inplace_view` // in PyTorch's `native_functions.yaml` file. rewriter.replaceOpWithNewOp(op, sizeOp.getSelf(), op.getIdx()); return success(); }); } //===----------------------------------------------------------------------===// // AtenIsFloatingPointOp //===----------------------------------------------------------------------===// OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) { auto operandType = getSelf().getType().dyn_cast(); if (!operandType) return nullptr; if (operandType.hasDtype()) { bool isFloatType = operandType.getDtype().isa(); return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType); } // doesn't has dtype return nullptr; } //===----------------------------------------------------------------------===// // AtenAddTOp //===----------------------------------------------------------------------===// void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) { auto lhsListConstruct = op.getA().getDefiningOp(); if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct)) return failure(); auto rhsListConstruct = op.getB().getDefiningOp(); if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct)) return failure(); SmallVector concatenatedList; for (auto a : lhsListConstruct.getOperands()) { concatenatedList.push_back(a); } for (auto b : rhsListConstruct.getOperands()) { concatenatedList.push_back(b); } rewriter.replaceOpWithNewOp(op, op.getType(), concatenatedList); return success(); }); } //===----------------------------------------------------------------------===// // AtenSliceTOp //===----------------------------------------------------------------------===// void 
AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenSliceTOp op, PatternRewriter &rewriter) { auto valueList = op.getL(); auto listConstructOp = valueList.getDefiningOp(); if (!listConstructOp || isListPotentiallyMutated(listConstructOp)) { return failure(); } SmallVector listElements = llvm::to_vector<4>(listConstructOp.getElements()); int64_t size = static_cast(listElements.size()); int64_t start; int64_t end; int64_t step; if (op.getStart().getType().isa()) { start = 0; } else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) { return failure(); } if (op.getEnd().getType().isa()) { end = listElements.size(); } else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { return failure(); } if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { return failure(); } start = start >= 0 ? start : start + size; start = start >= 0 ? start : 0; end = end >= 0 ? end : end + size; end = end < size ? end : size; SmallVector newListElements; for (int64_t i = start; i < end; i += step) { newListElements.push_back(listElements[i]); } rewriter.replaceOpWithNewOp( op, Torch::ListType::get(listElements[0].getType()), newListElements); return success(); }); } //===----------------------------------------------------------------------===// // AtenEqIntListOp //===----------------------------------------------------------------------===// OpFoldResult AtenEqIntListOp::fold(FoldAdaptor adaptor) { auto lhsLiteral = getA().getDefiningOp(); if (!lhsLiteral) return nullptr; auto rhsLiteral = getB().getDefiningOp(); if (!rhsLiteral) return nullptr; // If the sizes don't match, then we know the lists aren't equal. if (lhsLiteral.getNumOperands() != rhsLiteral.getNumOperands()) return getI1IntegerAttr(getContext(), false); // If the sizes match and all corresponding list elements are the same Value, // then we know the lists are equal. 
// Note that we can't prove that the lists are not-equal with this method, // since two different Value's might dynamically be equal. if (llvm::all_of( llvm::zip(lhsLiteral.getOperands(), rhsLiteral.getOperands()), [](const auto &pair) { return std::get<0>(pair) == std::get<1>(pair); })) return getI1IntegerAttr(getContext(), true); return nullptr; } //===----------------------------------------------------------------------===// // PrimTupleConstructOp //===----------------------------------------------------------------------===// LogicalResult PrimTupleConstructOp::verify() { if (!(isValidSubtype( Torch::TupleType::get(getContext(), llvm::to_vector<6>(getElements().getType())), getResult().getType()))) return emitOpError( "failed to verify that contained types correspond to operand types"); return success(); } //===----------------------------------------------------------------------===// // PrimTupleIndexOp //===----------------------------------------------------------------------===// void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) { auto tupleConstruct = op.getTup().getDefiningOp(); if (!tupleConstruct) return failure(); int64_t i; if (!matchPattern(op.getI(), m_TorchConstantInt(&i))) return failure(); if (i >= (int64_t)tupleConstruct.getElements().size()) return failure(); // TODO: We should have a clear picture of whether we want to consistently // allow refinement, and where. It seems desirable to require precise // type equality for TupleConstruct / TupleIndex, but that might break // things. 
Value replacement = tupleConstruct.getElements()[i]; if (replacement.getType() != op.getType()) { if (op.getType().isa()) { replacement = rewriter.create( op.getLoc(), op.getType(), replacement); } else { return failure(); } } rewriter.replaceOp(op, replacement); return success(); }); } //===----------------------------------------------------------------------===// // PrimUninitializedOp //===----------------------------------------------------------------------===// void PrimUninitializedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) { if (!op.use_empty()) return failure(); rewriter.eraseOp(op); return success(); }); } //===----------------------------------------------------------------------===// // PrimTupleUnpackOp //===----------------------------------------------------------------------===// void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) { auto tupleConstruct = op.getTup().getDefiningOp(); if (!tupleConstruct) return failure(); llvm::SmallVector derefinedElements; // The result types may be supertypes of the tuple element types. // Ensure we maintain the exact type, with identity `derefine`s being // folded. 
for (auto [type, element] : llvm::zip(op.getResultTypes(), tupleConstruct.getElements())) { derefinedElements.push_back( rewriter.createOrFold(op.getLoc(), type, element)); } rewriter.replaceOp(op, derefinedElements); return success(); }); } //===----------------------------------------------------------------------===// // PrimListUnpackOp //===----------------------------------------------------------------------===// void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimListUnpackOp op, PatternRewriter &rewriter) { auto torchList = op.getOperand(); if (isListPotentiallyMutated(torchList)) { return failure(); } auto listConstruct = torchList.getDefiningOp(); if (!listConstruct) return failure(); rewriter.replaceOp(op, listConstruct.getElements()); return success(); }); } static PrimDictConstructOp getDictConstructIfNotModified(Value torchDict) { if (!llvm::all_of(torchDict.getUsers(), [](Operation *op) { return isa(op); })) return nullptr; return torchDict.getDefiningOp(); } //===----------------------------------------------------------------------===// // Aten__Getitem__DictStrOp //===----------------------------------------------------------------------===// OpFoldResult Aten__Getitem__DictStrOp::fold(FoldAdaptor adaptor) { auto dictConstruct = getDictConstructIfNotModified(getSelf()); if (!dictConstruct) return nullptr; auto targetKey = getKey(); for (auto i : llvm::zip(dictConstruct.getKeys(), dictConstruct.getValues())) { auto k = std::get<0>(i); if (k == targetKey) return std::get<1>(i); } return nullptr; } //===----------------------------------------------------------------------===// // Aten__Contains__StrOp //===----------------------------------------------------------------------===// OpFoldResult Aten__Contains__StrOp::fold(FoldAdaptor adaptor) { auto dictConstruct = getDictConstructIfNotModified(getDict()); if (!dictConstruct) return nullptr; auto targetKey = getKey(); for (auto 
key : dictConstruct.getKeys()) { if (key == targetKey) return getI1IntegerAttr(getContext(), true); } return nullptr; } //===----------------------------------------------------------------------===// // Aten__Contains__IntListOp //===----------------------------------------------------------------------===// static bool isListConstructNotModified(Value torchList) { return llvm::all_of(torchList.getUsers(), [](Operation *op) { return isa(op); }); } OpFoldResult Aten__Contains__IntListOp::fold(FoldAdaptor adaptor) { auto itemConstruct = getItem(); if (!isListConstructNotModified(getL())) return nullptr; int64_t item; SmallVector list; if (!matchPattern(itemConstruct, m_TorchConstantInt(&item))) return nullptr; if (!matchPattern(getL(), m_TorchListOfConstantInts(list))) return nullptr; for (auto elem : list) { if (elem == item) return getI1IntegerAttr(getContext(), true); } return getI1IntegerAttr(getContext(), false); } using BinaryIntOperatorFn = std::function; static OpFoldResult atenBinaryIntOperatorFoldHelper(ArrayRef operands, BinaryIntOperatorFn f) { auto intLhs = operands[0].dyn_cast_or_null(); auto intRhs = operands[1].dyn_cast_or_null(); if (!intLhs || !intRhs) { return nullptr; } return IntegerAttr::get( intLhs.getType(), f(intLhs.getValue().getSExtValue(), intRhs.getValue().getSExtValue())); } using BinaryFloatOperatorFn = std::function; static OpFoldResult atenBinaryFloatOperatorFoldHelper(ArrayRef operands, BinaryFloatOperatorFn f) { double lhs, rhs; auto parseDoubleAttribute = [](Attribute attr, double &value) -> bool { if (auto intLhs = attr.dyn_cast_or_null()) { value = static_cast(intLhs.getValue().getSExtValue()); } else if (auto floatLhs = attr.dyn_cast_or_null()) { value = floatLhs.getValue().convertToDouble(); } else { return false; } return true; }; if (!parseDoubleAttribute(operands[0], lhs) || !parseDoubleAttribute(operands[1], rhs)) { return nullptr; } return getF64FloatAttr(operands[0].getContext(), f(lhs, rhs)); } 
//===----------------------------------------------------------------------===// // AtenAliasOp //===----------------------------------------------------------------------===// OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { return getOperand(); } //===----------------------------------------------------------------------===// // AtenFloordivIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) { return std::floor(a / (double)b); }); } //===----------------------------------------------------------------------===// // AtenRemainderIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; }); } //===----------------------------------------------------------------------===// // AtenAddIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; }); } //===----------------------------------------------------------------------===// // AtenSubIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; }); } //===----------------------------------------------------------------------===// // AtenCatOp //===----------------------------------------------------------------------===// OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { // We set a maximum folding size of 16. This is a reasonable upper limit // for shape computations. 
constexpr int64_t kMaxFoldSize = 16; auto list = getOperand(0).getDefiningOp(); if (!list) return nullptr; auto elements = list.getElements(); if (elements.size() == 1 && elements[0].getType() == getResult().getType()) return list.getElements()[0]; auto resultTy = dyn_cast(getType()); if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; auto bResultTy = resultTy.toBuiltinTensor(); if (!bResultTy.hasStaticShape() || bResultTy.getNumElements() > kMaxFoldSize) return nullptr; auto dimAttr = dyn_cast_or_null(adaptor.getDim()); if (!dimAttr) return nullptr; auto dim = dimAttr.getValue().getSExtValue(); dim += dim < 0 ? bResultTy.getRank() : 0; for (int i = 0, s = bResultTy.getRank(); i < s; ++i) { if (i == dim) continue; if (bResultTy.getDimSize(i) != 1) return nullptr; } llvm::SmallVector values; for (auto operand : list.getOperands()) { DenseElementsAttr dattr; if (!matchPattern(operand, m_Constant(&dattr))) return nullptr; auto oty = dyn_cast(dattr.getType()); if (!oty) return nullptr; if (dattr.isSplat()) { for (int i = 0, s = oty.getDimSize(dim); i < s; ++i) values.push_back(dattr.getSplatValue()); } else { auto evals = dattr.getValues(); for (int i = 0, s = oty.getDimSize(dim); i < s; ++i) values.push_back(evals[i]); } } return DenseElementsAttr::get(bResultTy.clone(resultTy.getDtype()), values); } void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenCatOp op, PatternRewriter &rewriter) { auto list = op.getTensors().getDefiningOp(); auto resultTy = dyn_cast(op.getType()); if (!list || !resultTy) return failure(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return failure(); llvm::SmallVector filtered; for (auto operand : list.getOperands()) { auto operandTy = dyn_cast(operand.getType()); if (!operandTy || !operandTy.hasSizes()) return failure(); int64_t adim = dim < 0 ? 
dim + operandTy.getSizes().size() : dim; if (operandTy.getSizes()[adim] != 0) filtered.push_back(operand); } if (filtered.size() == list.getNumOperands()) return failure(); auto newlist = rewriter.create( op.getLoc(), list.getType(), filtered); rewriter.replaceOpWithNewOp(op, op.getType(), newlist, op.getDim()); return success(); }); } //===----------------------------------------------------------------------===// // AtenBroadcastToOp //===----------------------------------------------------------------------===// OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || !outType.hasDtype()) return nullptr; if (!inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr; auto inSizes = inType.getSizes(); auto outSizes = outType.getSizes(); if (inSizes.size() == outSizes.size()) { bool sameSizes = true; for (int i = 0, s = inSizes.size(); i < s; ++i) sameSizes &= inSizes[i] == outSizes[i]; if (sameSizes) return getOperand(0); } auto selfAttr = dyn_cast_or_null(adaptor.getSelf()); if (!selfAttr) return nullptr; if (!selfAttr.isSplat()) return nullptr; auto attrty = RankedTensorType::get(outType.getSizes(), outType.getDtype()); return DenseElementsAttr::get(attrty, selfAttr.getSplatValue()); } //===----------------------------------------------------------------------===// // AtenSliceTensorOp //===----------------------------------------------------------------------===// OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { DenseElementsAttr input = dyn_cast_or_null(adaptor.getSelf()); IntegerAttr start = dyn_cast_or_null(adaptor.getStart()); IntegerAttr end = dyn_cast_or_null(adaptor.getEnd()); IntegerAttr step = dyn_cast_or_null(adaptor.getStep()); IntegerAttr dim = dyn_cast_or_null(adaptor.getDim()); if (start && end && step && step.getValue().getSExtValue() == 1 && 
start.getValue().getSExtValue() == 0 && end.getValue().getSExtValue() == std::numeric_limits::max()) return getOperand(0); auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || !inType.hasDtype() || !outType.hasDtype() || inType.getDtype() != outType.getDtype()) return nullptr; if (inType.getSizes().size() != outType.getSizes().size() || !inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr; if (input && input.isSplat()) return DenseElementsAttr::get( outType.toBuiltinTensor().clone(inType.getDtype()), input.getSplatValue()); int count = 1; for (auto dim : outType.getSizes()) count = count * dim; if (count == 0) return {}; if (!dim) return nullptr; int64_t dimInt = dim.getValue().getSExtValue(); if (dimInt < 0) dimInt += inType.getSizes().size(); bool unaryNonDim = true; for (int i = 0, s = outType.getSizes().size(); i < s; ++i) unaryNonDim &= outType.getSizes()[i] == 1 || i == dimInt; // Fold the slice if the output tensor is relatively small, currently // coded to 16: if (input && start && step && dim && count < 16 && unaryNonDim && count < 16) { int64_t inCount = input.getNumElements(); int64_t begin = start.getValue().getSExtValue(); int64_t stride = step.getValue().getSExtValue(); if (stride < 1) return {}; int64_t limit = end.getValue().getSExtValue(); begin = begin < 0 ? begin + inCount : begin; limit = limit < 0 ? limit + inCount : limit; limit = limit < 0 ? 
inType.getSizes()[dimInt] : limit; limit = std::min(limit, inType.getSizes()[dimInt]); llvm::SmallVector values; for (int i = begin; i < limit; i += stride) values.push_back(input.getValues()[i]); return DenseElementsAttr::get( outType.toBuiltinTensor().clone(inType.getDtype()), values); } // If the input and output shapes are the same we can just fold: for (size_t i = 0; i < inType.getSizes().size(); ++i) { if (inType.getSizes()[i] != outType.getSizes()[i]) return nullptr; } return getOperand(0); } //===----------------------------------------------------------------------===// // AtenMulIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); if ((lConstant && lhs == 0) || (rConstant && rhs == 0)) return getI64IntegerAttr(getContext(), 0); if (lConstant && rConstant) return getI64IntegerAttr(getContext(), lhs * rhs); return nullptr; } //===----------------------------------------------------------------------===// // AtenMulFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenMulFloatOp::fold(FoldAdaptor adaptor) { return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) { return a * b; }); } //===----------------------------------------------------------------------===// // AtenSubFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) { return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) { return a - b; }); } //===----------------------------------------------------------------------===// // AtenAddOp //===----------------------------------------------------------------------===// OpFoldResult 
AtenAddOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } if (adaptor.getA().isa() && adaptor.getB().isa()) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t { return a + b; }); } return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) -> double { return a + b; }); } //===----------------------------------------------------------------------===// // AtenMulOp //===----------------------------------------------------------------------===// OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } if (adaptor.getA().isa() && adaptor.getB().isa()) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t { return a * b; }); } return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) -> double { return a * b; }); } //===----------------------------------------------------------------------===// // AtenSubOp //===----------------------------------------------------------------------===// OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } if (adaptor.getA().isa() && adaptor.getB().isa()) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t { return a - b; }); } return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) -> double { return a - b; }); } //===----------------------------------------------------------------------===// // AtenDivOp //===----------------------------------------------------------------------===// OpFoldResult AtenDivOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } // Since AtenDivOp always returns float value, we don't need to deal with the // case where the operands are both integers separately. 
return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) -> double { return a / b; }); } //===----------------------------------------------------------------------===// // AtenAddFloatIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) { return a + b; }); } //===----------------------------------------------------------------------===// // AtenPowIntFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenPowIntFloatOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA() || !adaptor.getB()) { return nullptr; } return atenBinaryFloatOperatorFoldHelper( adaptor.getOperands(), [](double a, double b) { return std::pow(a, b); }); } //===----------------------------------------------------------------------===// // AtenCeilScalarOp //===----------------------------------------------------------------------===// OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA()) { return nullptr; } auto floatValue = adaptor.getA().dyn_cast_or_null(); if (!floatValue) { return nullptr; } return getI64IntegerAttr( getContext(), static_cast(std::ceil(floatValue.getValue().convertToDouble()))); } //===----------------------------------------------------------------------===// // AtenNegIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenNegIntOp::fold(FoldAdaptor adaptor) { int64_t c; if (matchPattern(getOperand(), m_TorchConstantInt(&c))) return getI64IntegerAttr(getContext(), -c); return nullptr; } //===----------------------------------------------------------------------===// // AtenNegFloatOp //===----------------------------------------------------------------------===// OpFoldResult 
AtenNegFloatOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA()) { return nullptr; } auto value = adaptor.getA().dyn_cast_or_null(); if (!value) { return nullptr; } return getF64FloatAttr(getContext(), -value.getValue().convertToDouble()); } //===----------------------------------------------------------------------===// // AtenSqrtIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) { int64_t c; if (matchPattern(getOperand(), m_TorchConstantInt(&c))) return getF64FloatAttr(getContext(), std::sqrt(c)); return nullptr; } //===----------------------------------------------------------------------===// // PrimDtypeOp //===----------------------------------------------------------------------===// OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) { BaseTensorType tensorType = getA().getType().cast(); if (tensorType.hasDtype()) { torch_upstream::ScalarType scalarType = Torch::getScalarTypeForType(tensorType.getDtype()); return getI64IntegerAttr(getContext(), static_cast(scalarType)); } return nullptr; } //===----------------------------------------------------------------------===// // PrimDeviceOp //===----------------------------------------------------------------------===// void PrimDeviceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimDeviceOp op, PatternRewriter &rewriter) { // Device information isn't relevant to torch-mlir, just replace it with // "cpu". 
rewriter.replaceOpWithNewOp(op, "cpu"); return success(); }); } //===----------------------------------------------------------------------===// // AtenCudaOp //===----------------------------------------------------------------------===// void AtenCudaOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenCudaOp op, PatternRewriter &rewriter) { // Device information isn't relevant to torch-mlir auto inputTensor = op.getSelf(); rewriter.replaceOp(op, inputTensor); return success(); }); } //===----------------------------------------------------------------------===// // AtenDeviceWithIndexOp //===----------------------------------------------------------------------===// void AtenDeviceWithIndexOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenDeviceWithIndexOp op, PatternRewriter &rewriter) { std::string type; int64_t index; if (!matchPattern(op.getType(), m_TorchConstantStr(type))) { return rewriter.notifyMatchFailure( op, "unimplemented: type must be a constant string"); } if (!matchPattern(op.getIndex(), m_TorchConstantInt(&index))) { return rewriter.notifyMatchFailure( op, "unimplemented: index must be a constant integer"); } rewriter.replaceOpWithNewOp( op, type + ":" + std::to_string(index)); return success(); }); } //===----------------------------------------------------------------------===// // AtenTensorOp //===----------------------------------------------------------------------===// OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { // If a torch.aten.tensor op is initialized by a list with a constant, single // element, fold it into a torch.vtensor.literal auto resultTy = dyn_cast(getType()); Type eTy = resultTy.getDtype(); ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); SmallVector data; if (matchPattern(getData(), m_TorchListOfConstantInts(data)) && data.size() == 1) { Attribute attribute = IntegerAttr::get(eTy, data[0]); 
return DenseElementsAttr::get(shapedTy, attribute); } return nullptr; } //===----------------------------------------------------------------------===// // AtenTensorOp //===----------------------------------------------------------------------===// OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) { auto selfTy = dyn_cast(getSelf().getType()); auto resultTy = dyn_cast(getType()); if (!selfTy || !resultTy || !selfTy.hasSizes() || !resultTy.hasDtype() || !resultTy.hasSizes()) return {}; llvm::SmallVector values(selfTy.getSizes()); if (llvm::any_of(values, [](int64_t d) { return d == Torch::kUnknownSize; })) return {}; auto dty = dyn_cast(resultTy.getDtype()); if (!dty) return {}; llvm::SmallVector attrs; for (auto val : values) { attrs.push_back(IntegerAttr::get(dty, val)); } auto attrty = RankedTensorType::get(resultTy.getSizes(), dty); return DenseElementsAttr::get(attrty, attrs); } //===----------------------------------------------------------------------===// // AtenIntTensorOp //===----------------------------------------------------------------------===// OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) { // If a scalar number is converted to a 0-d tensor and passed on to // aten.Int.Tensor, fold to the scalar number. if (auto numToTensorScalar = getA().getDefiningOp()) return numToTensorScalar.getA(); if (auto tensorIntOp = getA().getDefiningOp()) return tensorIntOp.getT(); return nullptr; } //===----------------------------------------------------------------------===// // AtenFloatTensorOp //===----------------------------------------------------------------------===// OpFoldResult AtenFloatTensorOp::fold(FoldAdaptor adaptor) { // If a scalar number is converted to a 0-d tensor and passed on to // aten.Float.Tensor, fold to the scalar number. 
if (auto numToTensorScalar = getA().getDefiningOp()) return numToTensorScalar.getA(); return nullptr; } //===----------------------------------------------------------------------===// // AtenDivFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenDivFloatOp::fold(FoldAdaptor adaptor) { double lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantFloat(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantFloat(&rhs)); if (lConstant && lhs == 0.0) return getF64FloatAttr(getContext(), 0.0); if (lConstant && rConstant && rhs == 1.0) return getF64FloatAttr(getContext(), lhs); if (lConstant && rConstant) return getF64FloatAttr(getContext(), lhs / rhs); return nullptr; } //===----------------------------------------------------------------------===// // AtenDivIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); if (lConstant && rConstant) return getF64FloatAttr(getContext(), double(lhs) / rhs); return nullptr; } //===----------------------------------------------------------------------===// // AtenIndexSelectOp //===----------------------------------------------------------------------===// OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) { auto self = getSelf(); auto index = getIndex(); auto selfTy = dyn_cast(self.getType()); auto indexTy = dyn_cast(index.getType()); auto resultTy = dyn_cast(getType()); if (!selfTy || !indexTy || !resultTy || !selfTy.hasSizes() || !indexTy.hasSizes() || !resultTy.hasSizes() || !selfTy.hasDtype() || !indexTy.hasDtype() || !resultTy.hasDtype()) return nullptr; auto selfSizes = selfTy.getSizes(); auto indexSizes = indexTy.getSizes(); auto resultSizes = resultTy.getSizes(); if (selfTy.getDtype() != 
resultTy.getDtype() || selfSizes.size() != resultSizes.size() || indexSizes.size() != 1) return nullptr; // If the selection results in a tensor of the same dimensions as the // input, the selection must have specified every index of the input, // so the result is exactly the same as the input. bool fullTensor = true; for (int i = 0, s = selfSizes.size(); i < s; ++i) { fullTensor &= selfSizes[i] == resultSizes[i]; fullTensor &= selfSizes[i] != Torch::kUnknownSize; fullTensor &= resultSizes[i] != Torch::kUnknownSize; } if (fullTensor && indexSizes[0] == 1) return self; // If the input tensor, index dimension, or indexes are non-constant, // can't fold. auto selfAttr = dyn_cast_or_null(adaptor.getSelf()); auto dimAttr = dyn_cast_or_null(adaptor.getDim()); auto indexAttr = dyn_cast_or_null(adaptor.getIndex()); if (!selfAttr || !dimAttr || !indexAttr) return {}; // If the input's dimensions are all 1 except for one dimension, and if // there is a single index in the index list (as detected by the result // dimension being 1), then fold to a <1x1x...x1> tensor literal containing // a single element. Handles float and int types. int64_t dimInt = dimAttr.getInt(); // If the selected dim is negative, count backwards from the last dim if (dimInt < 0) dimInt = selfSizes.size() + dimInt; assert(uint64_t(dimInt) < selfSizes.size() && "Selected dim > number of dims"); for (int i = 0, s = selfSizes.size(); i < s; ++i) { if ((selfSizes[i] != 1 && i != dimInt) || resultSizes[i] != 1) return nullptr; } // Get the single index value for the selected dimension auto splatValue = indexAttr.getSplatValue(); int64_t indexInt = getIntAttrAsSigned(splatValue); indexInt = indexInt < 0 && selfSizes[dimInt] ? indexInt + selfSizes[dimInt] : indexInt; // Extract the single constant value from the input tensor and turn the // extracted value into a single-element tensor of the output shape and dtype Attribute splattr = selfAttr.isSplat() ? 
selfAttr.getSplatValue() : selfAttr.getValues()[indexInt]; auto dty = resultTy.getDtype(); auto attrTy = resultTy.toBuiltinTensor().clone(dty); if (auto floatAttr = dyn_cast(splattr)) return DenseElementsAttr::get( attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble())); if (auto intAttr = dyn_cast(splattr)) { return DenseElementsAttr::get(attrTy, IntegerAttr::get(dty, intAttr.getValue())); } return nullptr; } //===----------------------------------------------------------------------===// // AtenItemOp //===----------------------------------------------------------------------===// OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) { // see if we have a constant tensor DenseElementsAttr attr; if (matchPattern(getOperand(), m_Constant(&attr))) { auto splat = attr.getSplatValue(); if (auto intAttr = dyn_cast(splat)) { return getI64IntegerAttr(getContext(), intAttr.getSInt()); } if (auto floatAttr = dyn_cast(splat)) { return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble()); } return nullptr; } return nullptr; } //===----------------------------------------------------------------------===// // AtenOnesOp, AtenZerosOp, AtenFullOp //===----------------------------------------------------------------------===// OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) { SmallVector sizes; if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { return nullptr; } Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); if (!resultTensorType || !resultTensorType.hasDtype() || !resultTensorType.hasSizes()) { return nullptr; } for (auto sz : sizes) if (sz == Torch::kUnknownSize || sz < 0) return nullptr; for (auto sz : resultTensorType.getSizes()) if (sz == Torch::kUnknownSize || sz < 0) return nullptr; ShapedType shapedty = mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType sizes, resultTensorType.getDtype()); if (!shapedty) { return nullptr; } auto elementType = shapedty.getElementType(); if 
(elementType.isa()) { Attribute attribute = IntegerAttr::get(elementType, 1); return DenseElementsAttr::get(shapedty, attribute); } if (elementType.isa()) { Attribute attribute = FloatAttr::get(elementType, 1.0); return DenseElementsAttr::get(shapedty, attribute); } return nullptr; } OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) { SmallVector sizes; if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { return nullptr; } Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); if (!resultTensorType || !resultTensorType.hasDtype() || !resultTensorType.hasSizes()) { return nullptr; } for (auto sz : sizes) if (sz == Torch::kUnknownSize || sz < 0) return nullptr; for (auto sz : resultTensorType.getSizes()) if (sz == Torch::kUnknownSize || sz < 0) return nullptr; ShapedType shapedty = mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType sizes, resultTensorType.getDtype()); if (!shapedty) { return nullptr; } auto elementType = shapedty.getElementType(); if (elementType.isa()) { Attribute attribute = IntegerAttr::get(elementType, 0); return DenseElementsAttr::get(shapedty, attribute); } if (elementType.isa()) { Attribute attribute = FloatAttr::get(elementType, 0.0); return DenseElementsAttr::get(shapedty, attribute); } return nullptr; } OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) { SmallVector sizes; if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { return nullptr; } Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); if (!resultTensorType || !resultTensorType.hasDtype() || !resultTensorType.hasSizes()) { return nullptr; } for (auto sz : sizes) if (sz == Torch::kUnknownSize || sz < 0) return nullptr; for (auto sz : resultTensorType.getSizes()) if (sz == Torch::kUnknownSize || sz < 0) return nullptr; ShapedType shapedty = mlir::RankedTensorType::get(sizes, resultTensorType.getDtype()); auto elementType = 
shapedty.getElementType(); if (elementType.isa()) { int64_t value = 0; if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) { Attribute attribute = IntegerAttr::get(elementType, value); return DenseElementsAttr::get(shapedty, attribute); } } if (elementType.isa()) { double value = 0.0; if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) { Attribute attribute = FloatAttr::get(elementType, value); return DenseElementsAttr::get(shapedty, attribute); } } return nullptr; } //===----------------------------------------------------------------------===// // AtenCeilFloatOp //===----------------------------------------------------------------------===// OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) { double c; if (matchPattern(getOperand(), m_TorchConstantFloat(&c))) return getI64IntegerAttr(getContext(), std::ceil(c)); return nullptr; } //===----------------------------------------------------------------------===// // AtenWhereSelfOp //===----------------------------------------------------------------------===// static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!attr || !ty.hasDtype() || !ty.hasSizes()) return nullptr; auto dty = ty.getDtype(); if (auto valueDense = dyn_cast(attr)) { if (!valueDense.isSplat()) return nullptr; auto splattr = valueDense.getSplatValue(); auto attrty = ty.toBuiltinTensor().clone(dty); return DenseElementsAttr::get(attrty, splattr); } if (auto intAttr = dyn_cast_or_null(attr)) { if (!isa(dty)) return nullptr; int64_t intval = intAttr.getInt(); auto attrty = ty.toBuiltinTensor().clone(dty); return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval)); } if (auto fpAttr = dyn_cast_or_null(attr)) { if (!isa(dty)) return nullptr; double dblval = fpAttr.getValueAsDouble(); auto attrty = ty.toBuiltinTensor().clone(dty); return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval)); } return nullptr; } OpFoldResult AtenWhereSelfOp::fold(FoldAdaptor adaptor) { auto dense = 
dyn_cast_or_null(adaptor.getCondition()); auto resultTy = dyn_cast(getType()); if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || !dense.isSplat()) return nullptr; auto condattr = dense.getSplatValue(); auto value = getSelf(); auto valueAttr = adaptor.getSelf(); if (condattr.isZero()) { value = getOther(); valueAttr = adaptor.getOther(); } auto valueTy = dyn_cast(value.getType()); if (valueTy && valueTy.hasSizes() && valueTy.hasDtype() && valueTy == resultTy) return value; return getBroadcastedAttr(valueAttr, resultTy); } //===----------------------------------------------------------------------===// // AtenWhereScalarOp //===----------------------------------------------------------------------===// OpFoldResult AtenWhereScalarOp::fold(FoldAdaptor adaptor) { auto dense = dyn_cast_or_null(adaptor.getCondition()); auto resultTy = dyn_cast(getType()); if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || !dense.isSplat()) return nullptr; auto condattr = dense.getSplatValue(); auto valueAttr = adaptor.getSelf(); if (condattr.isZero()) { valueAttr = adaptor.getOther(); } return getBroadcastedAttr(valueAttr, resultTy); } //===----------------------------------------------------------------------===// // AtenWhereScalarOtherOp //===----------------------------------------------------------------------===// OpFoldResult AtenWhereScalarOtherOp::fold(FoldAdaptor adaptor) { auto dense = dyn_cast_or_null(adaptor.getCondition()); auto resultTy = dyn_cast(getType()); if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || !dense.isSplat()) return nullptr; auto condattr = dense.getSplatValue(); auto valueAttr = adaptor.getSelf(); if (condattr.isZero()) { valueAttr = adaptor.getOther(); } return getBroadcastedAttr(valueAttr, resultTy); } //===----------------------------------------------------------------------===// // AtenWhereScalarSelfOp 
//===----------------------------------------------------------------------===// OpFoldResult AtenWhereScalarSelfOp::fold(FoldAdaptor adaptor) { auto dense = dyn_cast_or_null(adaptor.getCondition()); auto resultTy = dyn_cast(getType()); if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || !dense.isSplat()) return nullptr; auto condattr = dense.getSplatValue(); auto valueAttr = adaptor.getSelf(); if (condattr.isZero()) { valueAttr = adaptor.getOther(); } return getBroadcastedAttr(valueAttr, resultTy); } //===----------------------------------------------------------------------===// // PrimMaxIntOp //===----------------------------------------------------------------------===// OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { // If both operands are the same, then the operation is an identity. if (getA() == getB()) return getA(); auto lhs = adaptor.getA().dyn_cast_or_null(); auto rhs = adaptor.getB().dyn_cast_or_null(); if (!lhs || !rhs) return nullptr; // Torch semantics are that !torch.int is 64-bit signed. 
return IntegerAttr::get( lhs.getType(), std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue())); } //===----------------------------------------------------------------------===// // PrimNumToTensorScalarOp //===----------------------------------------------------------------------===// OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) { Attribute a = adaptor.getA(); auto resultTy = cast(getType()); if (!a) return {}; if (!resultTy.hasDtype() || !resultTy.hasSizes()) return {}; auto dty = resultTy.getDtype(); if (auto iattr = dyn_cast(a)) { a = IntegerAttr::get(dty, iattr.getInt()); } else if (auto fattr = dyn_cast(a)) { a = FloatAttr::get(dty, fattr.getValueAsDouble()); } auto mlirTensorType = RankedTensorType::get(resultTy.getSizes(), resultTy.getDtype()); return SplatElementsAttr::get(mlirTensorType, a); } //===----------------------------------------------------------------------===// // PrimMinSelfIntOp //===----------------------------------------------------------------------===// OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) { auto list = getOperand().getDefiningOp(); if (!list) return nullptr; // TODO: What does it return for an empty list? if (list->getNumOperands() == 0) return nullptr; SmallVector values; for (auto operand : list->getOperands()) { int64_t value; if (!matchPattern(operand, m_TorchConstantInt(&value))) return nullptr; values.push_back(value); } return getI64IntegerAttr(getContext(), *std::min_element(values.begin(), values.end())); } //===----------------------------------------------------------------------===// // PrimMinIntOp //===----------------------------------------------------------------------===// OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) { // If both operands are the same, then the operation is an identity. 
if (getA() == getB()) return getA(); auto lhs = adaptor.getA().dyn_cast_or_null(); auto rhs = adaptor.getB().dyn_cast_or_null(); if (!lhs || !rhs) return nullptr; // Torch semantics are that !torch.int is 64-bit signed. return IntegerAttr::get( lhs.getType(), std::min(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue())); } //===----------------------------------------------------------------------===// // ShapeCalculateOp //===----------------------------------------------------------------------===// template static void getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point, SmallVectorImpl ®ions) { if (!point.getRegionOrNull()) { // First thing the op does is branch into the calculation. regions.emplace_back(&op.getCalculation()); return; } if (point == op.getBody()) { // Body returns control to the outer op, passing through results. regions.emplace_back(op.getResults()); return; } assert(point == op.getCalculation()); // Calculation branches to the body. regions.emplace_back(&op.getBody()); } void ShapeCalculateOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { getSuccessorRegionsForCalculateOp(*this, point, regions); } //===----------------------------------------------------------------------===// // DtypeCalculateOp //===----------------------------------------------------------------------===// void DtypeCalculateOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { getSuccessorRegionsForCalculateOp(*this, point, regions); } //===----------------------------------------------------------------------===// // ShapeCalculateYieldShapesOp //===----------------------------------------------------------------------===// MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands( RegionBranchPoint point) { // The shape operands don't get forwarded to the body. // MutableOperandRange always has an owning operation, even if empty, so // create a 0-length range. 
return MutableOperandRange(*this, /*start=*/0, /*length=*/0); } LogicalResult ShapeCalculateYieldShapesOp::verify() { auto parent = cast(getOperation()->getParentOp()); if (parent.getNumResults() != getNumOperands()) return emitOpError("expected number of shapes to match number of results"); return success(); } //===----------------------------------------------------------------------===// // AtenNormScalarOp //===----------------------------------------------------------------------===// LogicalResult AtenNormScalarOp::verify() { // Verificaion of input type for torch.aten.norm.Scalar. // Per PyTorch docs, only float and complex types are valid for norm // operation. auto inTensor = getSelf().getType().cast(); // If no dtype is specified, it will default to a float one. if (!inTensor.hasDtype()) { return success(); } auto inTensorDtype = inTensor.getDtype(); // Check if dtype is one of those supported by norm operation. // ComplexType will match any torch complex types, but each float must be // checked individually. if (!inTensorDtype.isa()) { return emitOpError( "expected a float or complex type for input tensor, but got ") << inTensorDtype; } return success(); } //===----------------------------------------------------------------------===// // AtenPermuteOp //===----------------------------------------------------------------------===// LogicalResult AtenPermuteOp::verify() { // Verification of the permute op for input & output dimensions with // statically known sizes. 
SmallVector permutation; auto permutationObtained = getListConstructElements(getDims(), permutation); if (!permutationObtained) { return success(); } auto outType = getResult().getType().cast(); auto inType = getSelf().getType().cast(); if (!outType.hasSizes() || !inType.hasSizes()) { return success(); } auto outShape = outType.getSizes(); auto inShape = inType.getSizes(); auto outRank = outShape.size(); if (outRank != inShape.size()) { return emitOpError( "expected input and output tensors to have same rank, but ") << inShape.size() << " != " << outRank << '.'; } if (outRank != permutation.size()) { return emitOpError() << "expected permutation to have size equal result " "tensor rank. The permutation has " << permutation.size() << " elements, the output has rank " << outRank << '.'; } // Initialization of the reverse permutation. -1 denotes an unknown // permutation index. SmallVector reversePermutation(outRank, -1); // In this loop: // (1) check that the permutation indices are in bounds, and not duplicated. // (2) populate reversePermutation (to check for duplicates). // (3) check that the input and output shapes agree with the permutation. For // example, if the permutation is (1,2,0) and the input shape is (2,3,5), // then the output shape must be (3,5,2). for (uint64_t to = 0; to < outRank; ++to) { int64_t from; auto fromIsSet = matchPattern(permutation[to], m_TorchConstantInt(&from)); if (!fromIsSet) { continue; } // if 'from' is the unkwown index, continue. 
if (from == -1) { continue; } if (!isValidDim(from, outRank)) { return emitError("observed invalid index in permutation (") << from << ") for input tensor of rank " << outRank << '.'; } if (reversePermutation[from] != -1) { return emitOpError("has a duplicate dimension (") << from << ") in its permutation " << getDims() << '.'; } reversePermutation[from] = to; auto dimSizesDefined = inShape[from] != kUnknownSize && outShape[to] != kUnknownSize; auto dimSizesDifferent = inShape[from] != outShape[to]; if (dimSizesDefined && dimSizesDifferent) { return emitOpError("has a permutation which is not compatible with the " "input and output shapes. ") << "The input shape in dimension " << from << " is " << inShape[from] << ", and the output shape in dimension " << to << " is " << outShape[to] << " : they should be the same with this permutation. "; } } return success(); } //===----------------------------------------------------------------------===// // AtenLinalgCrossOp //===----------------------------------------------------------------------===// LogicalResult AtenLinalgCrossOp::verify() { auto selfType = getSelf().getType().cast(); auto otherType = getOther().getType().cast(); if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() || !otherType.hasSizes()) { return success(); } Type selfDtype = selfType.getDtype(); Type otherDtype = otherType.getDtype(); // the operation succeeds only if both inputs have the same dtype if (selfDtype != otherDtype) { return emitOpError("input tensors must have the same dtype, but got ") << selfDtype << " and " << otherDtype; } // Check if any of the input tensors has torch.bool dtype. // The operation does not support this type. // The docs state that only float, double, cfloat and cdouble dtypes are // supported, but, when testing, it fails only for boolean dtype. Update to // fit the docs if necessary. 
// https://pytorch.org/docs/stable/generated/torch.linalg.cross.html if (selfDtype.isSignlessInteger(1) || otherDtype.isSignlessInteger(1)) { return emitOpError("input tensors must not have bool dtype"); } ArrayRef selfShape = selfType.getSizes(); ArrayRef otherShape = otherType.getSizes(); int64_t selfRank = selfShape.size(); int64_t otherRank = otherShape.size(); // check if both input tensors have the same number of dims if (selfRank != otherRank) { return emitOpError("input tensors must have the same number of dimensions, " "but got ") << selfRank << " and " << otherRank; } // convert dim to an integer type int64_t dim; if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) { return success(); } // check if dim is in the correct range if (dim >= selfRank || dim < -selfRank) { return emitOpError("dim expected to be in rank of [") << -selfRank << ", " << selfRank - 1 << "], but got " << dim; } // compensate for possible negative dim value if (dim < 0) { dim += selfRank; } // check if the size of the dimensions specified by 'dim' is equal to 3 // (required by the operation) if ((selfShape[dim] != 3 && selfShape[dim] != kUnknownSize) || (otherShape[dim] != 3 && otherShape[dim] != kUnknownSize)) { return emitOpError("inputs dimension ") << dim << " must have length 3, but got " << selfShape[dim] << " and " << otherShape[dim]; } // Check if there is a disparity between dimension sizes. // Dimensions at the same index must either have the same size, // or one of them must be equal to 1. 
int32_t i = 0; for (auto [selfCurrent, otherCurrent] : llvm::zip_equal(selfShape, otherShape)) { if (selfCurrent != otherCurrent && selfCurrent != 1 && otherCurrent != 1) { return emitOpError("the size of first tensor (") << selfCurrent << ") must match the size of second tensor (" << otherCurrent << ") at dimension " << i << " or one of them must be 1"; } ++i; } return success(); } //===----------------------------------------------------------------------===// // DtypeCalculateYieldDtypesOp //===----------------------------------------------------------------------===// MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands( RegionBranchPoint point) { // The dtype operands don't get forwarded to the body. // MutableOperandRange always has an owning operation, even if empty, so // create a 0-length range. return MutableOperandRange(*this, /*start=*/0, /*length=*/0); } LogicalResult DtypeCalculateYieldDtypesOp::verify() { auto parent = cast(getOperation()->getParentOp()); if (parent.getNumResults() != getNumOperands()) return emitOpError("expected number of dtypes to match number of results"); return success(); } //===----------------------------------------------------------------------===// // GlobalSlotModuleInitializerOp //===----------------------------------------------------------------------===// LogicalResult GlobalSlotModuleInitializerOp::verify() { // We centralize all verification of the global slots and the // InitializeGlobalSlotsOp into here, since it requires processing the whole // module. // TODO: We should really have a `torch.module` and have this initializer be // a region attached to it. ModuleOp module = cast(getOperation()->getParentOp()); for (auto op : module.getOps()) { if (op.getOperation() != getOperation()) return op.emitError("there must be only one global slot initializer"); } // Collect the relevant symbol names we will verify. 
DenseSet knownGlobalSlots; for (auto op : module.getOps()) knownGlobalSlots.insert(op.getSymNameAttr()); DenseSet initializedGlobalSlots; auto initialize = cast(getBody()->getTerminator()); for (Attribute symName : initialize.getSlotSymNames()) { auto wasInserted = initializedGlobalSlots .insert(symName.cast().getAttr()) .second; if (!wasInserted) return initialize.emitError("duplicate initialization of global slot: ") << symName; } auto lessThanByStringValue = [](Attribute lhs, Attribute rhs) { return lhs.cast().getValue() < rhs.cast().getValue(); }; auto known = llvm::to_vector(knownGlobalSlots); llvm::sort(known, lessThanByStringValue); auto initialized = llvm::to_vector(initializedGlobalSlots); llvm::sort(initialized, lessThanByStringValue); // Check that the global slots in the module are all initialized. SymbolTable symbolTable(module); if (initializedGlobalSlots != knownGlobalSlots) { InFlightDiagnostic diag = initialize.emitOpError( "must have one initializer for each global slot in the module"); for (auto knownGlobalSlot : known) { auto symName = FlatSymbolRefAttr::get(knownGlobalSlot.cast()); if (!initializedGlobalSlots.count(knownGlobalSlot)) { diag.attachNote( symbolTable.lookup(symName.getAttr()).getLoc()) .append("missing global slot initializer for ", symName); } } for (auto initializedGlobalSlot : initialized) { if (!knownGlobalSlots.count(initializedGlobalSlot)) { diag.attachNote().append( "unexpected global slot initializer for non-existent global slot ", FlatSymbolRefAttr::get(initializedGlobalSlot.cast())); } } return diag; } // Check that initial values satisfy type bounds. 
for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) { auto symName = initialize.getSlotSymNames()[i].cast(); auto initialValue = initialize.getOperand(i); auto globalSlotOp = symbolTable.lookup(symName.getValue()); if (!isValidSubtype(initialValue.getType(), globalSlotOp.getTypeBound())) { return initialize.emitOpError().append( "initial value for global slot ", symName, " has type ", initialValue.getType(), " which is not within the bound ", globalSlotOp.getTypeBound()); } } auto walkResult = getOperation()->walk([](Operation *op) { // We only permit a small set of ops in the module initializer. // These ops are essentially those which can be produced by the IValue // importer. if (op->hasTrait()) return WalkResult::advance(); op->emitOpError() << "is not allowed in a module initializer"; return WalkResult::interrupt(); }); if (walkResult.wasInterrupted()) return failure(); return success(); } //===----------------------------------------------------------------------===// // InitializeGlobalSlotsOp //===----------------------------------------------------------------------===// ParseResult InitializeGlobalSlotsOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); if (parser.parseLSquare()) return failure(); SmallVector slotSymNames; while (!succeeded(parser.parseOptionalRSquare())) { NamedAttrList dummy; StringAttr slotSymName; if (parser.parseSymbolName(slotSymName, "dummy", dummy)) return failure(); slotSymNames.push_back(FlatSymbolRefAttr::get(slotSymName)); if (parser.parseLParen()) return failure(); OpAsmParser::UnresolvedOperand initialValue; if (parser.parseOperand(initialValue)) return failure(); Type initialValueType; if (parser.parseColonType(initialValueType)) return failure(); if (parser.parseRParen()) return failure(); if (parser.resolveOperand(initialValue, initialValueType, result.operands)) return failure(); } result.addAttribute("slotSymNames", 
ArrayAttr::get(parser.getContext(), slotSymNames)); return success(); } void InitializeGlobalSlotsOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{"slotSymNames"}); p << " ["; p.printNewline(); for (int i = 0, e = getNumOperands(); i < e; ++i) { p << " " << getSlotSymNames()[i] << "(" << getInitialValues()[i] << " : " << getInitialValues()[i].getType() << ")"; p.printNewline(); } p << "]"; } LogicalResult InitializeGlobalSlotsOp::verify() { if (getInitialValues().size() != getSlotSymNames().size()) return emitOpError("expected number of operands to match number of slots"); return success(); }