mirror of https://github.com/llvm/torch-mlir
Make TorchOps.cpp faster to iterate on.
The ODS-generated code included via the `TorchOps.cpp.inc` file takes a very long time to compile. This PR isolates it into its own file so that the build system can cache it. It creates a new file, `TorchOpsODSGenerated.cpp`, whose only purpose is to include the `TorchOps.cpp.inc` file. Doing so required moving to the "new" way to define verifiers, since the static `verify` free functions in TorchOps.cpp were no longer accessible from the .inc file after it moved to `TorchOpsODSGenerated.cpp`. On my machine, this drops the build time of TorchOps.cpp (such as when iterating on a canonicalizer) from over 40 seconds to under 10 seconds. 10 seconds still isn't great, but at least it isn't "go get a coffee" type of waiting.

pull/679/head snapshot-20220316.328
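For context, the "new" way to define verifiers referenced above is ODS's `hasVerifier` flag: instead of splicing a string of C++ into the generated code (which requires a static free `verify` function to be visible in whatever translation unit compiles the `.inc` file), setting `let hasVerifier = 1;` makes ODS declare a `verify()` member function on the op class, whose definition can live in any file. A minimal sketch of the two styles, using a hypothetical `MyOp` that is not part of this diff (assumes the usual MLIR includes and `using namespace mlir;`, as in TorchOps.cpp):

    // Old style: the .td file contains
    //   let verifier = "return ::verify(*this);";
    // so the generated code calls a static free function that must be
    // visible in the translation unit that compiles TorchOps.cpp.inc.
    static LogicalResult verify(MyOp op) {
      if (op.getNumOperands() == 0)
        return op.emitOpError("expected at least one operand");
      return success();
    }

    // New style: the .td file contains
    //   let hasVerifier = 1;
    // and ODS declares `LogicalResult verify();` on the op class, so the
    // definition can stay in TorchOps.cpp while the generated code is
    // compiled separately in TorchOpsODSGenerated.cpp.
    LogicalResult MyOp::verify() {
      if (getNumOperands() == 0)
        return emitOpError("expected at least one operand");
      return success();
    }

This is why, throughout the diff below, each `let verifier = ...` line becomes `let hasVerifier = 1;` and each `static LogicalResult verify(SomeOp op)` becomes a `SomeOp::verify()` member function.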
parent 8da7d90611
commit 3b66b4925a
TorchOps.td
@@ -62,7 +62,7 @@ def Torch_NnModuleOp : Torch_Op<"nn_module", [
   let arguments = (ins);
   let results = (outs Torch_NnModuleType:$result);
   let regions = (region SizedRegion<1>:$region);
-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;

   let assemblyFormat = "$region attr-dict `:` qualified(type($result))";
@@ -146,7 +146,7 @@ def Torch_ClassTypeOp : Torch_Op<"class_type", [
   let arguments = (ins SymbolNameAttr:$sym_name);
   let results = (outs);
   let regions = (region SizedRegion<1>:$region);
-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;

   let assemblyFormat = "$sym_name $region attr-dict";
 }
@@ -360,7 +360,7 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
     AnyTorchListType:$result
   );

-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;

   let assemblyFormat = [{
     $elements attr-dict `:` functional-type(operands, results)
@@ -382,7 +382,7 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
     Torch_DictType:$result
   );

-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;

   let assemblyFormat = [{
     `keys` `(` ($keys^ `:` qualified(type($keys)))? `)` `values` `(` ($values^ `:` qualified(type($values)))? `)` attr-dict `->` qualified(type($result))
@@ -474,7 +474,7 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
     $maxTripCount `,` $initialCondition `,` `init` `(` $iterArgsInit `)` $region
     attr-dict `:` functional-type(operands, results)
   }];
-  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
+  let hasVerifier = 1;
   let extraClassDeclaration = [{
     /// Returns true if this loop is "for-like". Otherwise it is "while-like"
     /// and this function returns false.
@@ -528,7 +528,7 @@ def Torch_PrimIfOp : Torch_Op<"prim.If", [
   let regions = (region SizedRegion<1>:$thenRegion, SizedRegion<1>:$elseRegion);
   // Indicate that the operation has a custom parser and printer method.
   let hasCustomAssemblyFormat = 1;
-  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
+  let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }

@@ -877,7 +877,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($result))
   }];
-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;
 }

 def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
@@ -907,7 +907,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
   let assemblyFormat = [{
     $operand attr-dict `:` qualified(type($result))
   }];
-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;
 }

 def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
@@ -1104,7 +1104,7 @@ def Torch_ShapeCalculateOp : Torch_Op<"shape.calculate", [
   let assemblyFormat = [{
     $body `shapes` $shapeCalculation attr-dict `:` type($results)
   }];
-  let verifier = [{ return RegionBranchOpInterface::verifyTypes(*this); }];
+  let hasVerifier = 1;
 }

 def Torch_ShapeCalculateYieldOp : Torch_Op<"shape.calculate.yield", [
@@ -1147,7 +1147,7 @@ def Torch_ShapeCalculateYieldShapesOp : Torch_Op<"shape.calculate.yield.shapes",
     attr-dict ($results^ `:` type($results))?
   }];

-  let verifier = "return ::verify(*this);";
+  let hasVerifier = 1;
 }

 #endif // TORCH_OPS
TorchTypes.h
@@ -16,6 +16,16 @@ namespace mlir {
 namespace torch {
 namespace Torch {

+/// PyTorch has a well-developed notion of subtyping.
+///
+/// This is a restricted subset of it that only handles a few special cases
+/// that we need to model.
+///
+/// TODO: Flesh this out.
+/// TODO: Decide / properly model the distinction between PEP 483 / Python
+/// subtyping vs "more static information".
+bool isValidSubtype(Type subtype, Type type);
+
 class NonValueTensorType;
 class ValueTensorType;
CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_library(TorchMLIRTorchDialect
   TorchDialect.cpp
   TorchOps.cpp
+  TorchOpsODSGenerated.cpp
   TorchTypes.cpp

   ADDITIONAL_HEADER_DIRS
TorchOps.cpp
@@ -125,61 +125,13 @@ LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 // NnModuleOp
 //===----------------------------------------------------------------------===//

-static LogicalResult verify(NnModuleOp op) {
-  for (Operation &child : *op.getBody())
+LogicalResult NnModuleOp::verify() {
+  for (Operation &child : *getBody())
     if (!isa<SlotOp, NnModuleTerminatorOp>(&child))
       return child.emitOpError() << "is not allowed inside 'torch.nn_module'";
   return success();
 }

-// PyTorch has a well-developed notion of subtyping.
-//
-// This is a restricted subset of it.
-//
-// TODO: Flesh this out.
-// TODO: Decide / properly model the distinction between PEP 483 / Python
-// subtyping vs "more static information".
-bool isValidSubtype(Type subtype, Type type) {
-  if (subtype == type)
-    return true;
-
-  if (auto any = type.dyn_cast<AnyType>())
-    return true;
-
-  if (auto number = type.dyn_cast<NumberType>())
-    return subtype.isa<IntType>() || subtype.isa<Torch::FloatType>();
-
-  if (auto optional = type.dyn_cast<OptionalType>())
-    return isValidSubtype(subtype, optional.getContainedType()) ||
-           subtype.isa<Torch::NoneType>();
-
-  if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
-    if (!subtype.isa<Torch::TupleType>())
-      return false;
-    auto subtypes = subtype.cast<Torch::TupleType>().getContainedTypes();
-    auto types = tuple.getContainedTypes();
-    if (subtypes.size() != types.size())
-      return false;
-    for (auto t : llvm::zip(subtypes, types)) {
-      if (!isValidSubtype(std::get<0>(t), std::get<1>(t)))
-        return false;
-    }
-    return true;
-  }
-
-  // TODO: This is not subtyping according to PEP 483. See description
-  // of NonValueTensorType.
-  if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
-      type ==
-          NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
-    return true;
-
-  if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
-      type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
-    return true;
-  return false;
-}
-
 LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto classType = symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(
       *this, SymbolRefAttr::get(getContext(), getClassName()));
@@ -213,15 +165,15 @@ LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 // PrimListConstructOp
 //===----------------------------------------------------------------------===//

-static LogicalResult verify(PrimListConstructOp op) {
-  auto resultType = op.getResult().getType();
+LogicalResult PrimListConstructOp::verify() {
+  auto resultType = getResult().getType();
   auto resultElementType = resultType.dyn_cast<ListType>().getContainedType();
   auto matchResultElementType = [&](Type type) {
     return isValidSubtype(type, resultElementType);
   };
-  if (!llvm::all_of(op->getOperandTypes(), matchResultElementType)) {
-    return op.emitError() << "operand types should have the same type as the "
-                             "list contained type";
+  if (!llvm::all_of(getOperandTypes(), matchResultElementType)) {
+    return emitError() << "operand types should have the same type as the "
+                          "list contained type";
   }

   return success();
@@ -231,18 +183,16 @@ static LogicalResult verify(PrimListConstructOp op) {
 // PrimDictConstructOp
 //===----------------------------------------------------------------------===//

-static LogicalResult verify(PrimDictConstructOp op) {
+LogicalResult PrimDictConstructOp::verify() {
   auto isValidSubTypeOf = [](Type expectedType) {
     return [=](Type type) { return isValidSubtype(type, expectedType); };
   };

-  Type keyType = op.getKeyType();
-  if (!llvm::all_of(op.keys().getTypes(), isValidSubTypeOf(keyType)))
-    return op.emitError() << "keys should be of Dict key type";
+  if (!llvm::all_of(keys().getTypes(), isValidSubTypeOf(getKeyType())))
+    return emitError() << "keys should be of Dict key type";

-  Type valueType = op.getValueType();
-  if (!llvm::all_of(op.values().getTypes(), isValidSubTypeOf(valueType)))
-    return op.emitError() << "values should be of Dict value type";
+  if (!llvm::all_of(values().getTypes(), isValidSubTypeOf(getValueType())))
+    return emitError() << "values should be of Dict value type";

   return success();
 }
@@ -251,9 +201,9 @@ static LogicalResult verify(PrimDictConstructOp op) {
 // ClassTypeOp
 //===----------------------------------------------------------------------===//

-static LogicalResult verify(ClassTypeOp op) {
+LogicalResult ClassTypeOp::verify() {
   llvm::StringMap<Operation *> namesToOps;
-  for (Operation &child : op.getBody()->without_terminator()) {
+  for (Operation &child : getBody()->without_terminator()) {
     if (!isa<AttrOp, MethodOp>(&child))
       return child.emitOpError() << "is not allowed inside `torch.class_type`";
     StringRef name;
@@ -265,8 +215,8 @@ static LogicalResult verify(ClassTypeOp op) {
   auto it = itAndWasInserted.first;
   bool wasInserted = itAndWasInserted.second;
   if (!wasInserted) {
-    auto diag = op.emitOpError().append(
-        "has duplicate attr/method with name '", name, "'");
+    auto diag = emitOpError().append("has duplicate attr/method with name '",
+                                     name, "'");
     diag.attachNote(it->second->getLoc())
         .append("see first conflicting attr/method here");
     diag.attachNote(child.getLoc())
@@ -282,6 +232,10 @@ static LogicalResult verify(ClassTypeOp op) {
 // PrimLoopOp
 //===----------------------------------------------------------------------===//

+LogicalResult PrimLoopOp::verify() {
+  return RegionBranchOpInterface::verifyTypes(*this);
+}
+
 OperandRange PrimLoopOp::getSuccessorEntryOperands(unsigned index) {
   assert(index == 0);
   return iterArgsInit();
@@ -321,6 +275,10 @@ PrimLoopConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
 // PrimIfOp
 //===----------------------------------------------------------------------===//

+LogicalResult PrimIfOp::verify() {
+  return RegionBranchOpInterface::verifyTypes(*this);
+}
+
 ParseResult PrimIfOp::parse(OpAsmParser &parser, OperationState &result) {
   // Create the regions.
   result.regions.reserve(2);
@@ -1073,13 +1031,11 @@ void TensorStaticInfoCastOp::getCanonicalizationPatterns(
 // CopyToNonValueTensorOp
 //===----------------------------------------------------------------------===//

-static LogicalResult verify(CopyToNonValueTensorOp op) {
-  auto resultType = op.getResult().getType().cast<BaseTensorType>();
-  auto operandType = op.getOperand().getType().cast<BaseTensorType>();
-  if (!resultType.hasSameSizesAndDtype(operandType)) {
-    return op.emitError()
-           << "operand and result must have same sizes and dtype";
-  }
+LogicalResult CopyToNonValueTensorOp::verify() {
+  auto resultType = getResult().getType().cast<BaseTensorType>();
+  auto operandType = getOperand().getType().cast<BaseTensorType>();
+  if (!resultType.hasSameSizesAndDtype(operandType))
+    return emitError() << "operand and result must have same sizes and dtype";
   return success();
 }
@@ -1102,13 +1058,11 @@ void CopyToNonValueTensorOp::getEffects(
 // CopyToValueTensorOp
 //===----------------------------------------------------------------------===//

-static LogicalResult verify(CopyToValueTensorOp op) {
-  auto resultType = op.getResult().getType().cast<BaseTensorType>();
-  auto operandType = op.getOperand().getType().cast<BaseTensorType>();
-  if (!resultType.hasSameSizesAndDtype(operandType)) {
-    return op.emitError()
-           << "operand and result must have same sizes and dtype";
-  }
+LogicalResult CopyToValueTensorOp::verify() {
+  auto resultType = getResult().getType().cast<BaseTensorType>();
+  auto operandType = getOperand().getType().cast<BaseTensorType>();
+  if (!resultType.hasSameSizesAndDtype(operandType))
+    return emitError() << "operand and result must have same sizes and dtype";
   return success();
 }
@@ -1588,6 +1542,10 @@ OpFoldResult PrimMinSelfIntOp::fold(ArrayRef<Attribute> operands) {
 // ShapeCalculateOp
 //===----------------------------------------------------------------------===//

+LogicalResult ShapeCalculateOp::verify() {
+  return RegionBranchOpInterface::verifyTypes(*this);
+}
+
 void ShapeCalculateOp::getSuccessorRegions(
     Optional<unsigned> index, ArrayRef<Attribute> operands,
     SmallVectorImpl<RegionSuccessor> &regions) {
@@ -1620,13 +1578,9 @@ MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
   return MutableOperandRange(*this, /*start=*/0, /*length=*/0);
 }

-static LogicalResult verify(ShapeCalculateYieldShapesOp op) {
-  auto parent = op->getParentOfType<ShapeCalculateOp>();
-  if (parent.getNumResults() != op.getNumOperands())
-    return op.emitOpError(
-        "expected number of shapes to match number of results");
+LogicalResult ShapeCalculateYieldShapesOp::verify() {
+  auto parent = cast<ShapeCalculateOp>(getOperation()->getParentOp());
+  if (parent.getNumResults() != getNumOperands())
+    return emitOpError("expected number of shapes to match number of results");
   return success();
 }
-
-#define GET_OP_CLASSES
-#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
TorchOpsODSGenerated.cpp (new file)
@@ -0,0 +1,35 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file is meant to include the `TorchOps.cpp.inc` file and compile it
+// separately from the main TorchOps.cpp file. The .inc file takes a very long
+// time to compile, and slows down the iteration time on folders,
+// canonicalizations, parser/printers, etc. in the actual TorchOps.cpp file, so
+// it makes sense to isolate it and let the build system cache it.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LLVM.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+#define GET_OP_CLASSES
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.cpp.inc"
TorchTypes.cpp
@@ -17,6 +17,51 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;

+//===----------------------------------------------------------------------===//
+// isValidSubtype
+//===----------------------------------------------------------------------===//
+
+bool Torch::isValidSubtype(Type subtype, Type type) {
+  if (subtype == type)
+    return true;
+
+  if (auto any = type.dyn_cast<AnyType>())
+    return true;
+
+  if (auto number = type.dyn_cast<NumberType>())
+    return subtype.isa<IntType>() || subtype.isa<Torch::FloatType>();
+
+  if (auto optional = type.dyn_cast<OptionalType>())
+    return isValidSubtype(subtype, optional.getContainedType()) ||
+           subtype.isa<Torch::NoneType>();
+
+  if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
+    if (!subtype.isa<Torch::TupleType>())
+      return false;
+    auto subtypes = subtype.cast<Torch::TupleType>().getContainedTypes();
+    auto types = tuple.getContainedTypes();
+    if (subtypes.size() != types.size())
+      return false;
+    for (auto t : llvm::zip(subtypes, types)) {
+      if (!isValidSubtype(std::get<0>(t), std::get<1>(t)))
+        return false;
+    }
+    return true;
+  }
+
+  // TODO: This is not subtyping according to PEP 483. See description
+  // of NonValueTensorType.
+  if (subtype.isa<NonValueTensorType>() && type.isa<NonValueTensorType>() &&
+      type ==
+          NonValueTensorType::getWithLeastStaticInformation(type.getContext()))
+    return true;
+
+  if (subtype.isa<ValueTensorType>() && type.isa<ValueTensorType>() &&
+      type == ValueTensorType::getWithLeastStaticInformation(type.getContext()))
+    return true;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // TupleType
 //===----------------------------------------------------------------------===//