From 3b66b4925a7287f624a7c441ac4870873613557b Mon Sep 17 00:00:00 2001
From: Sean Silva
Date: Wed, 16 Mar 2022 00:54:57 +0000
Subject: [PATCH] Make TorchOps.cpp faster to iterate on.

The ODS-generated code included via the `TorchOps.cpp.inc` file takes a
very long time to compile. This PR isolates it into its own file so that
the build system can cache it.

This PR creates a new file `TorchOpsODSGenerated.cpp` just to include
the `TorchOps.cpp.inc` file. Doing so required moving to the "new" way
to define verifiers (sketched in the note after the diff), since the
static `verify` free functions in TorchOps.cpp weren't accessible from
the .inc file after it was moved to `TorchOpsODSGenerated.cpp`.

On my machine, this drops the build time of TorchOps.cpp (such as when
iterating on a canonicalizer) from >40 seconds to <10 seconds. 10
seconds still isn't great, but at least it isn't "go get a coffee" type
of waiting.
---
 .../torch-mlir/Dialect/Torch/IR/TorchOps.td   |  20 +--
 .../torch-mlir/Dialect/Torch/IR/TorchTypes.h  |  10 ++
 lib/Dialect/Torch/IR/CMakeLists.txt           |   1 +
 lib/Dialect/Torch/IR/TorchOps.cpp             | 130 ++++++------------
 lib/Dialect/Torch/IR/TorchOpsODSGenerated.cpp |  35 +++++
 lib/Dialect/Torch/IR/TorchTypes.cpp           |  45 ++++++
 6 files changed, 143 insertions(+), 98 deletions(-)
 create mode 100644 lib/Dialect/Torch/IR/TorchOpsODSGenerated.cpp

diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
index 2ec682434..1882b6a04 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -62,7 +62,7 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
   let arguments = (ins);
   let results = (outs Torch_NnModuleType:$result);
   let regions = (region SizedRegion<1>:$region);
-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;
 
   let assemblyFormat = "$region attr-dict `:` qualified(type($result))";
 
@@ -146,7 +146,7 @@ def Torch_ClassTypeOp : Torch_Op<"class_type", [
   let arguments = (ins SymbolNameAttr:$sym_name);
   let results = (outs);
   let regions = (region SizedRegion<1>:$region);
-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;
 
   let assemblyFormat = "$sym_name $region attr-dict";
 }
@@ -360,7 +360,7 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
     AnyTorchListType:$result
   );
 
-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;
 
   let assemblyFormat = [{
     $elements attr-dict `:` functional-type(operands, results)
@@ -382,7 +382,7 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
     Torch_DictType:$result
   );
 
-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;
 
   let assemblyFormat = [{
     `keys` `(` ($keys^ `:` qualified(type($keys)))? `)` `values` `(` ($values^ `:` qualified(type($values)))? `)` attr-dict `->` qualified(type($result))
@@ -474,7 +474,7 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
     $maxTripCount `,` $initialCondition `,` `init` `(` $iterArgsInit `)` $region
     attr-dict `:` functional-type(operands, results)
   }];
-  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
+  let hasVerifier = 1;
   let extraClassDeclaration = [{
     /// Returns true if this loop is "for-like". Otherwise it is "while-like"
     /// and this function returns false.
@@ -528,7 +528,7 @@ def Torch_PrimIfOp : Torch_Op<"prim.If", [
   let regions = (region SizedRegion<1>:$thenRegion, SizedRegion<1>:$elseRegion);
   // Indicate that the operation has a custom parser and printer method.
   let hasCustomAssemblyFormat = 1;
-  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
+  let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -877,7 +877,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($result))
   }];
-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;
 }
 
 def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
@@ -907,7 +907,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($result))
   }];
-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;
 }
 
 def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
@@ -1104,7 +1104,7 @@ def Torch_ShapeCalculateOp : Torch_Op<"shape.calculate", [
   let assemblyFormat = [{
     $body `shapes` $shapeCalculation attr-dict `:` type($results)
   }];
-  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
+  let hasVerifier = 1;
 }
 
 def Torch_ShapeCalculateYieldOp : Torch_Op<"shape.calculate.yield", [
@@ -1147,7 +1147,7 @@ def Torch_ShapeCalculateYieldShapesOp : Torch_Op<"shape.calculate.yield.shapes",
 
     attr-dict ($results^ `:` type($results))?
   }];
-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;
 }
 
 #endif // TORCH_OPS
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
index dd8806840..46771dc72 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h
@@ -16,6 +16,16 @@ namespace mlir {
 namespace torch {
 namespace Torch {
 
+/// PyTorch has a well-developed notion of subtyping.
+///
+/// This is a restricted subset of it that only handles a few special cases
+/// that we need to model.
+///
+/// TODO: Flesh this out.
+/// TODO: Decide / properly model the distinction between PEP 483 / Python
+/// subtyping vs "more static information".
+bool isValidSubtype(Type subtype, Type type);
+
 class NonValueTensorType;
 class ValueTensorType;
 
diff --git a/lib/Dialect/Torch/IR/CMakeLists.txt b/lib/Dialect/Torch/IR/CMakeLists.txt
index 0b565bb45..8ddaad76a 100644
--- a/lib/Dialect/Torch/IR/CMakeLists.txt
+++ b/lib/Dialect/Torch/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_library(TorchMLIRTorchDialect
   TorchDialect.cpp
   TorchOps.cpp
+  TorchOpsODSGenerated.cpp
   TorchTypes.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index d81770942..984941cfa 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -125,61 +125,13 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 // NnModuleOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(NnModuleOp op) {
-  for (Operation &child : *op.getBody())
+LogicalResult NnModuleOp::verify() {
+  for (Operation &child : *getBody())
     if (!isa<SlotOp, NnModuleTerminatorOp>(&child))
       return child.emitOpError() << "is not allowed inside 'torch.nn_module'";
   return success();
 }
 
-// PyTorch has a well-developed notion of subtyping.
-//
-// This is a restricted subset of it.
-//
-// TODO: Flesh this out.
-// TODO: Decide / properly model the distinction between PEP 483 / Python
-// subtyping vs "more static information".
-bool isValidSubtype(Type subtype, Type type) {
-  if (subtype == type)
-    return true;
-
-  if (auto any = type.dyn_cast<AnyType>())
-    return true;
-
-  if (auto number = type.dyn_cast<NumberType>())
-    return subtype.isa<IntType>() || subtype.isa<Torch::FloatType>();
-
-  if (auto optional = type.dyn_cast<OptionalType>())
-    return isValidSubtype(subtype, optional.getContainedType()) ||
-           subtype.isa<Torch::NoneType>();
-
-  if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
-    if (!subtype.isa<Torch::TupleType>())
-      return false;
-    auto subtypes = subtype.cast<Torch::TupleType>().getContainedTypes();
-    auto types = tuple.getContainedTypes();
-    if (subtypes.size() != types.size())
-      return false;
-    for (auto t : llvm::zip(subtypes, types)) {
-      if (!isValidSubtype(std::get<0>(t), std::get<1>(t)))
-        return false;
-    }
-    return true;
-  }
-
-  // TODO: This is not subtyping according to PEP 483. See description
-  // of NonValueTensorType.
-  if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
-      type ==
-          NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
-    return true;
-
-  if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
-      type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
-    return true;
-  return false;
-}
-
 LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto classType = symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(
       *this, SymbolRefAttr::get(getContext(), getClassName()));
@@ -213,15 +165,15 @@ LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 // PrimListConstructOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(PrimListConstructOp op) {
-  auto resultType = op.getResult().getType();
+LogicalResult PrimListConstructOp::verify() {
+  auto resultType = getResult().getType();
   auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
   auto matchResultElementType = [&](Type type) {
     return isValidSubtype(type, resultElementType);
   };
-  if (!llvm::all_of(op->getOperandTypes(), matchResultElementType)) {
-    return op.emitError() << "operand types should have the same type as the "
-                             "list contained type";
+  if (!llvm::all_of(getOperandTypes(), matchResultElementType)) {
+    return emitError() << "operand types should have the same type as the "
+                          "list contained type";
   }
 
   return success();
@@ -231,18 +183,16 @@ static LogicalResult verify(PrimListConstructOp op) {
 // PrimDictConstructOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(PrimDictConstructOp op) {
+LogicalResult PrimDictConstructOp::verify() {
   auto isValidSubTypeOf = [](Type expectedType) {
     return [=](Type type) { return isValidSubtype(type, expectedType); };
   };
 
-  Type keyType = op.getKeyType();
-  if (!llvm::all_of(op.keys().getTypes(), isValidSubTypeOf(keyType)))
-    return op.emitError() << "keys should be of Dict key type";
+  if (!llvm::all_of(keys().getTypes(), isValidSubTypeOf(getKeyType())))
+    return emitError() << "keys should be of Dict key type";
 
-  Type valueType = op.getValueType();
-  if (!llvm::all_of(op.values().getTypes(), isValidSubTypeOf(valueType)))
-    return op.emitError() << "values should be of Dict value type";
+  if (!llvm::all_of(values().getTypes(), isValidSubTypeOf(getValueType())))
+    return emitError() << "values should be of Dict value type";
 
   return success();
 }
@@ -251,9 +201,9 @@ static LogicalResult verify(PrimDictConstructOp op) {
 // ClassTypeOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(ClassTypeOp op) {
+LogicalResult ClassTypeOp::verify() {
   llvm::StringMap<Operation *> namesToOps;
-  for (Operation &child : op.getBody()->without_terminator()) {
+  for (Operation &child : getBody()->without_terminator()) {
     if (!isa<AttrOp, MethodOp>(&child))
       return child.emitOpError() << "is not allowed inside `torch.class_type`";
     StringRef name;
@@ -265,8 +215,8 @@
     auto it = itAndWasInserted.first;
     bool wasInserted = itAndWasInserted.second;
     if (!wasInserted) {
-      auto diag = op.emitOpError().append(
-          "has duplicate attr/method with name '", name, "'");
+      auto diag = emitOpError().append("has duplicate attr/method with name '",
+                                       name, "'");
       diag.attachNote(it->second->getLoc())
           .append("see first conflicting attr/method here");
       diag.attachNote(child.getLoc())
@@ -282,6 +232,10 @@
 // PrimLoopOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult PrimLoopOp::verify() {
+  return RegionBranchOpInterface::verifyTypes(*this);
+}
+
 OperandRange PrimLoopOp::getSuccessorEntryOperands(unsigned index) {
   assert(index == 0);
   return iterArgsInit();
@@ -321,6 +275,10 @@ PrimLoopConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
 // PrimIfOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult PrimIfOp::verify() {
+  return RegionBranchOpInterface::verifyTypes(*this);
+}
+
 ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
   // Create the regions.
   result.regions.reserve(2);
@@ -1073,13 +1031,11 @@ void TensorStaticInfoCastOp::getCanonicalizationPatterns(
 // CopyToNonValueTensorOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(CopyToNonValueTensorOp op) {
-  auto resultType = op.getResult().getType().cast<NonValueTensorType>();
-  auto operandType = op.getOperand().getType().cast<ValueTensorType>();
-  if (!resultType.hasSameSizesAndDtype(operandType)) {
-    return op.emitError()
-           << "operand and result must have same sizes and dtype";
-  }
+LogicalResult CopyToNonValueTensorOp::verify() {
+  auto resultType = getResult().getType().cast<NonValueTensorType>();
+  auto operandType = getOperand().getType().cast<ValueTensorType>();
+  if (!resultType.hasSameSizesAndDtype(operandType))
+    return emitError() << "operand and result must have same sizes and dtype";
   return success();
 }
 
@@ -1102,13 +1058,11 @@ void CopyToNonValueTensorOp::getEffects(
 // CopyToValueTensorOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(CopyToValueTensorOp op) {
-  auto resultType = op.getResult().getType().cast<ValueTensorType>();
-  auto operandType = op.getOperand().getType().cast<NonValueTensorType>();
-  if (!resultType.hasSameSizesAndDtype(operandType)) {
-    return op.emitError()
-           << "operand and result must have same sizes and dtype";
-  }
+LogicalResult CopyToValueTensorOp::verify() {
+  auto resultType = getResult().getType().cast<ValueTensorType>();
+  auto operandType = getOperand().getType().cast<NonValueTensorType>();
+  if (!resultType.hasSameSizesAndDtype(operandType))
+    return emitError() << "operand and result must have same sizes and dtype";
   return success();
 }
 
@@ -1588,6 +1542,10 @@ OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
 // ShapeCalculateOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult ShapeCalculateOp::verify() {
+  return RegionBranchOpInterface::verifyTypes(*this);
+}
+
 void ShapeCalculateOp::getSuccessorRegions(
     Optional<unsigned> index, ArrayRef<Attribute> operands,
     SmallVectorImpl<RegionSuccessor> &regions) {
@@ -1620,13 +1578,9 @@ MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
   return MutableOperandRange(*this, /*start=*/0, /*length=*/0);
 }
 
-static LogicalResult verify(ShapeCalculateYieldShapesOp op) {
-  auto parent = op->getParentOfType<ShapeCalculateOp>();
-  if (parent.getNumResults() != op.getNumOperands())
-    return op.emitOpError(
-        "expected number of shapes to match number of results");
+LogicalResult ShapeCalculateYieldShapesOp::verify() {
+  auto parent = cast<ShapeCalculateOp>(getOperation()->getParentOp());
+  if (parent.getNumResults() != getNumOperands())
+    return emitOpError("expected number of shapes to match number of results");
   return success();
 }
-
-#define GET_OP_CLASSES
-#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
diff --git a/lib/Dialect/Torch/IR/TorchOpsODSGenerated.cpp b/lib/Dialect/Torch/IR/TorchOpsODSGenerated.cpp
new file mode 100644
index 000000000..133a63c73
--- /dev/null
+++ b/lib/Dialect/Torch/IR/TorchOpsODSGenerated.cpp
@@ -0,0 +1,35 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file is meant to include the `TorchOps.cpp.inc` file and compile it
+// separately from the main TorchOps.cpp file. The .inc file takes a very long
+// time to compile, and slows down the iteration time on folders,
+// canonicalizations, parser/printers, etc. in the actual TorchOps.cpp file, so
+// it makes sense to isolate it and let the build system cache it.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LLVM.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+#define GET_OP_CLASSES
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index 34c0e2789..a5fdc824b 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -17,6 +17,51 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
+//===----------------------------------------------------------------------===//
+// isValidSubtype
+//===----------------------------------------------------------------------===//
+
+bool Torch::isValidSubtype(Type subtype, Type type) {
+  if (subtype == type)
+    return true;
+
+  if (auto any = type.dyn_cast<AnyType>())
+    return true;
+
+  if (auto number = type.dyn_cast<NumberType>())
+    return subtype.isa<IntType>() || subtype.isa<Torch::FloatType>();
+
+  if (auto optional = type.dyn_cast<OptionalType>())
+    return isValidSubtype(subtype, optional.getContainedType()) ||
+           subtype.isa<Torch::NoneType>();
+
+  if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
+    if (!subtype.isa<Torch::TupleType>())
+      return false;
+    auto subtypes = subtype.cast<Torch::TupleType>().getContainedTypes();
+    auto types = tuple.getContainedTypes();
+    if (subtypes.size() != types.size())
+      return false;
+    for (auto t : llvm::zip(subtypes, types)) {
+      if (!isValidSubtype(std::get<0>(t), std::get<1>(t)))
+        return false;
+    }
+    return true;
+  }
+
+  // TODO: This is not subtyping according to PEP 483. See description
+  // of NonValueTensorType.
+  if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
+      type ==
+          NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
+    return true;
+
+  if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
+      type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
+    return true;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // TupleType
 //===----------------------------------------------------------------------===//
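
A note for reviewers on the verifier migration above: the old `let verifier = "return ::verify(*this);"` style splices that string into the op's generated `verify()` body inside `TorchOps.cpp.inc`, so the static free function it names must be visible in whichever translation unit includes the .inc. The new `let hasVerifier = 1;` style only declares a `verify()` member on the generated op class, so the hand-written definition can live anywhere and link normally. A compressed sketch of the two styles, using a hypothetical `FooOp` (not an op in this dialect; assumes the usual MLIR `using namespace mlir;` context and is not compilable on its own):

```cpp
// Old style: TableGen pastes the verifier string into the generated
// verify() body in TorchOps.cpp.inc, roughly:
//   LogicalResult FooOp::verify() { return ::verify(*this); }
// so a file-local function had to precede the .inc inclusion:
//   static LogicalResult verify(FooOp op) { ... }
//
// New style (`let hasVerifier = 1;`): TableGen only emits the declaration
//   LogicalResult verify();
// on the FooOp class. The definition below can stay in TorchOps.cpp while
// the slow generated .inc compiles separately in TorchOpsODSGenerated.cpp.
// FooOp and this particular check are hypothetical.
LogicalResult FooOp::verify() {
  if (getOperation()->getNumOperands() == 0)
    return emitOpError("expected at least one operand");
  return success();
}
```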
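For reference, the `isValidSubtype` helper that moves from TorchOps.cpp to TorchTypes.cpp encodes a small subtype relation: every type is a subtype of itself and of `!torch.any`; `!torch.int` and `!torch.float` are subtypes of `!torch.number`; both `T` and `!torch.none` are subtypes of `!torch.optional<T>`; tuples are compared element-wise; and tensor types are subtypes of the least-static-information tensor type of the same kind. A rough illustration of the intended relation follows; the `get` builder calls mirror common torch-mlir usage, but treat their exact signatures (especially `OptionalType::get`) as assumptions:

```cpp
#include <cassert>

#include "mlir/IR/MLIRContext.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::torch;

void isValidSubtypeExamples() {
  MLIRContext context;
  context.loadDialect<Torch::TorchDialect>();
  // Singleton Torch types; OptionalType::get(containedType) is assumed to
  // derive the context from its contained type.
  Type i = Torch::IntType::get(&context);
  Type f = Torch::FloatType::get(&context);
  Type number = Torch::NumberType::get(&context);
  Type none = Torch::NoneType::get(&context);
  Type optionalInt = Torch::OptionalType::get(i);

  assert(Torch::isValidSubtype(i, i));              // reflexivity
  assert(Torch::isValidSubtype(i, number));         // int <: number
  assert(Torch::isValidSubtype(f, number));         // float <: number
  assert(Torch::isValidSubtype(i, optionalInt));    // T <: Optional[T]
  assert(Torch::isValidSubtype(none, optionalInt)); // None <: Optional[T]
  assert(!Torch::isValidSubtype(number, i));        // relation is directed
}
```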