Make TorchOps.cpp faster to iterate on.

The ODS-generated code included via the `TorchOps.cpp.inc` file takes a
very long time to compile. This PR isolates it into its own file so that
the build system can cache it.

This PR creates a new file `TorchOpsODSGenerated.cpp` just to include
the `TorchOps.cpp.inc` file. Doing so required moving to the "new" way
to define verifiers, since the static `verify` free functions in
TorchOps.cpp weren't accessible from the .inc file after it was moved to
`TorchOpsODSGenerated.cpp`.

On my machine, this drops the build time of TorchOps.cpp (such as when
iterating on a canonicalizer) from >40 seconds to <10 seconds.
10 seconds still isn't great though, but at least it isn't "go get a
coffee" type of waiting.
pull/679/head snapshot-20220316.328
Sean Silva 2022-03-16 00:54:57 +00:00
parent 8da7d90611
commit 3b66b4925a
6 changed files with 143 additions and 98 deletions

View File

@ -62,7 +62,7 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
let arguments = (ins);
let results = (outs Torch_NnModuleType:$result);
let regions = (region SizedRegion<1>:$region);
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
let assemblyFormat = "$region attr-dict `:` qualified(type($result))";
@ -146,7 +146,7 @@ def Torch_ClassTypeOp : Torch_Op<"class_type", [
let arguments = (ins SymbolNameAttr:$sym_name);
let results = (outs);
let regions = (region SizedRegion<1>:$region);
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
let assemblyFormat = "$sym_name $region attr-dict";
}
@ -360,7 +360,7 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
AnyTorchListType:$result
);
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
let assemblyFormat = [{
$elements attr-dict `:` functional-type(operands, results)
@ -382,7 +382,7 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
Torch_DictType:$result
);
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
let assemblyFormat = [{
`keys` `(` ($keys^ `:` qualified(type($keys)))? `)` `values` `(` ($values^ `:` qualified(type($values)))? `)` attr-dict `->` qualified(type($result))
@ -474,7 +474,7 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
$maxTripCount `,` $initialCondition `,` `init` `(` $iterArgsInit `)` $region
attr-dict `:` functional-type(operands, results)
}];
let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true if this loop is "for-like". Otherwise it is "while-like"
/// and this function returns false.
@ -528,7 +528,7 @@ def Torch_PrimIfOp : Torch_Op<"prim.If", [
let regions = (region SizedRegion<1>:$thenRegion, SizedRegion<1>:$elseRegion);
// Indicate that the operation has a custom parser and printer method.
let hasCustomAssemblyFormat = 1;
let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
@ -877,7 +877,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($result))
}];
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
}
def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
@ -907,7 +907,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
let assemblyFormat = [{
$operand attr-dict `:` qualified(type($result))
}];
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
}
def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
@ -1104,7 +1104,7 @@ def Torch_ShapeCalculateOp : Torch_Op<"shape.calculate", [
let assemblyFormat = [{
$body `shapes` $shapeCalculation attr-dict `:` type($results)
}];
let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
let hasVerifier = 1;
}
def Torch_ShapeCalculateYieldOp : Torch_Op<"shape.calculate.yield", [
@ -1147,7 +1147,7 @@ def Torch_ShapeCalculateYieldShapesOp : Torch_Op<"shape.calculate.yield.shapes",
attr-dict ($results^ `:` type($results))?
}];
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
}
#endif // TORCH_OPS

View File

@ -16,6 +16,16 @@ namespace mlir {
namespace torch {
namespace Torch {
/// PyTorch has a well-developed notion of subtyping.
///
/// This is a restricted subset of it that only handles a few special cases
/// that we need to model.
///
/// TODO: Flesh this out.
/// TODO: Decide / properly model the distinction between PEP 483 / Python
/// subtyping vs "more static information".
bool isValidSubtype(Type subtype, Type type);
class NonValueTensorType;
class ValueTensorType;

View File

@ -1,6 +1,7 @@
add_mlir_library(TorchMLIRTorchDialect
TorchDialect.cpp
TorchOps.cpp
TorchOpsODSGenerated.cpp
TorchTypes.cpp
ADDITIONAL_HEADER_DIRS

View File

@ -125,61 +125,13 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// NnModuleOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(NnModuleOp op) {
for (Operation &child : *op.getBody())
LogicalResult NnModuleOp::verify() {
for (Operation &child : *getBody())
if (!isa<SlotOp, NnModuleTerminatorOp>(&child))
return child.emitOpError() << "is not allowed inside 'torch.nn_module'";
return success();
}
// PyTorch has a well-developed notion of subtyping.
//
// This is a restricted subset of it.
//
// TODO: Flesh this out.
// TODO: Decide / properly model the distinction between PEP 483 / Python
// subtyping vs "more static information".
bool isValidSubtype(Type subtype, Type type) {
if (subtype == type)
return true;
if (auto any = type.dyn_cast<AnyType>())
return true;
if (auto number = type.dyn_cast<NumberType>())
return subtype.isa<IntType>() || subtype.isa<Torch::FloatType>();
if (auto optional = type.dyn_cast<OptionalType>())
return isValidSubtype(subtype, optional.getContainedType()) ||
subtype.isa<Torch::NoneType>();
if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
if (!subtype.isa<Torch::TupleType>())
return false;
auto subtypes = subtype.cast<Torch::TupleType>().getContainedTypes();
auto types = tuple.getContainedTypes();
if (subtypes.size() != types.size())
return false;
for (auto t : llvm::zip(subtypes, types)) {
if (!isValidSubtype(std::get<0>(t), std::get<1>(t)))
return false;
}
return true;
}
// TODO: This is not subtyping according to PEP 483. See description
// of NonValueTensorType.
if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
type ==
NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
return false;
}
LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto classType = symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(
*this, SymbolRefAttr::get(getContext(), getClassName()));
@ -213,15 +165,15 @@ LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// PrimListConstructOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(PrimListConstructOp op) {
auto resultType = op.getResult().getType();
LogicalResult PrimListConstructOp::verify() {
auto resultType = getResult().getType();
auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
auto matchResultElementType = [&](Type type) {
return isValidSubtype(type, resultElementType);
};
if (!llvm::all_of(op->getOperandTypes(), matchResultElementType)) {
return op.emitError() << "operand types should have the same type as the "
"list contained type";
if (!llvm::all_of(getOperandTypes(), matchResultElementType)) {
return emitError() << "operand types should have the same type as the "
"list contained type";
}
return success();
@ -231,18 +183,16 @@ static LogicalResult verify(PrimListConstructOp op) {
// PrimDictConstructOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(PrimDictConstructOp op) {
LogicalResult PrimDictConstructOp::verify() {
auto isValidSubTypeOf = [](Type expectedType) {
return [=](Type type) { return isValidSubtype(type, expectedType); };
};
Type keyType = op.getKeyType();
if (!llvm::all_of(op.keys().getTypes(), isValidSubTypeOf(keyType)))
return op.emitError() << "keys should be of Dict key type";
if (!llvm::all_of(keys().getTypes(), isValidSubTypeOf(getKeyType())))
return emitError() << "keys should be of Dict key type";
Type valueType = op.getValueType();
if (!llvm::all_of(op.values().getTypes(), isValidSubTypeOf(valueType)))
return op.emitError() << "values should be of Dict value type";
if (!llvm::all_of(values().getTypes(), isValidSubTypeOf(getValueType())))
return emitError() << "values should be of Dict value type";
return success();
}
@ -251,9 +201,9 @@ static LogicalResult verify(PrimDictConstructOp op) {
// ClassTypeOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(ClassTypeOp op) {
LogicalResult ClassTypeOp::verify() {
llvm::StringMap<Operation *> namesToOps;
for (Operation &child : op.getBody()->without_terminator()) {
for (Operation &child : getBody()->without_terminator()) {
if (!isa<AttrOp, MethodOp>(&child))
return child.emitOpError() << "is not allowed inside `torch.class_type`";
StringRef name;
@ -265,8 +215,8 @@ static LogicalResult verify(ClassTypeOp op) {
auto it = itAndWasInserted.first;
bool wasInserted = itAndWasInserted.second;
if (!wasInserted) {
auto diag = op.emitOpError().append(
"has duplicate attr/method with name '", name, "'");
auto diag = emitOpError().append("has duplicate attr/method with name '",
name, "'");
diag.attachNote(it->second->getLoc())
.append("see first conflicting attr/method here");
diag.attachNote(child.getLoc())
@ -282,6 +232,10 @@ static LogicalResult verify(ClassTypeOp op) {
// PrimLoopOp
//===----------------------------------------------------------------------===//
LogicalResult PrimLoopOp::verify() {
return RegionBranchOpInterface::verifyTypes(*this);
}
OperandRange PrimLoopOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0);
return iterArgsInit();
@ -321,6 +275,10 @@ PrimLoopConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
// PrimIfOp
//===----------------------------------------------------------------------===//
LogicalResult PrimIfOp::verify() {
return RegionBranchOpInterface::verifyTypes(*this);
}
ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions.
result.regions.reserve(2);
@ -1073,13 +1031,11 @@ void TensorStaticInfoCastOp::getCanonicalizationPatterns(
// CopyToNonValueTensorOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(CopyToNonValueTensorOp op) {
auto resultType = op.getResult().getType().cast<BaseTensorType>();
auto operandType = op.getOperand().getType().cast<BaseTensorType>();
if (!resultType.hasSameSizesAndDtype(operandType)) {
return op.emitError()
<< "operand and result must have same sizes and dtype";
}
LogicalResult CopyToNonValueTensorOp::verify() {
auto resultType = getResult().getType().cast<BaseTensorType>();
auto operandType = getOperand().getType().cast<BaseTensorType>();
if (!resultType.hasSameSizesAndDtype(operandType))
return emitError() << "operand and result must have same sizes and dtype";
return success();
}
@ -1102,13 +1058,11 @@ void CopyToNonValueTensorOp::getEffects(
// CopyToValueTensorOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(CopyToValueTensorOp op) {
auto resultType = op.getResult().getType().cast<BaseTensorType>();
auto operandType = op.getOperand().getType().cast<BaseTensorType>();
if (!resultType.hasSameSizesAndDtype(operandType)) {
return op.emitError()
<< "operand and result must have same sizes and dtype";
}
LogicalResult CopyToValueTensorOp::verify() {
auto resultType = getResult().getType().cast<BaseTensorType>();
auto operandType = getOperand().getType().cast<BaseTensorType>();
if (!resultType.hasSameSizesAndDtype(operandType))
return emitError() << "operand and result must have same sizes and dtype";
return success();
}
@ -1588,6 +1542,10 @@ OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
// ShapeCalculateOp
//===----------------------------------------------------------------------===//
LogicalResult ShapeCalculateOp::verify() {
return RegionBranchOpInterface::verifyTypes(*this);
}
void ShapeCalculateOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
@ -1620,13 +1578,9 @@ MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
return MutableOperandRange(*this, /*start=*/0, /*length=*/0);
}
static LogicalResult verify(ShapeCalculateYieldShapesOp op) {
auto parent = op->getParentOfType<ShapeCalculateOp>();
if (parent.getNumResults() != op.getNumOperands())
return op.emitOpError(
"expected number of shapes to match number of results");
LogicalResult ShapeCalculateYieldShapesOp::verify() {
auto parent = cast<ShapeCalculateOp>(getOperation()->getParentOp());
if (parent.getNumResults() != getNumOperands())
return emitOpError("expected number of shapes to match number of results");
return success();
}
#define GET_OP_CLASSES
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"

View File

@ -0,0 +1,35 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
//
// This file is meant to include the `TorchOps.cpp.inc` file and compile it
// separately from the main TorchOps.cpp file. The .inc file takes a very long
// time to compile, and slows down the iteration time on folders,
// canonicalizations, parser/printers, etc. in the actual TorchOps.cpp file, so
// it makes sense to isolate it and let the build system cache it.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
#define GET_OP_CLASSES
#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"

View File

@ -17,6 +17,51 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
//===----------------------------------------------------------------------===//
// isValidSubtype
//===----------------------------------------------------------------------===//
bool Torch::isValidSubtype(Type subtype, Type type) {
if (subtype == type)
return true;
if (auto any = type.dyn_cast<AnyType>())
return true;
if (auto number = type.dyn_cast<NumberType>())
return subtype.isa<IntType>() || subtype.isa<Torch::FloatType>();
if (auto optional = type.dyn_cast<OptionalType>())
return isValidSubtype(subtype, optional.getContainedType()) ||
subtype.isa<Torch::NoneType>();
if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
if (!subtype.isa<Torch::TupleType>())
return false;
auto subtypes = subtype.cast<Torch::TupleType>().getContainedTypes();
auto types = tuple.getContainedTypes();
if (subtypes.size() != types.size())
return false;
for (auto t : llvm::zip(subtypes, types)) {
if (!isValidSubtype(std::get<0>(t), std::get<1>(t)))
return false;
}
return true;
}
// TODO: This is not subtyping according to PEP 483. See description
// of NonValueTensorType.
if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
type ==
NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
return true;
return false;
}
//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//