Add an `AllowedInModuleInitializer` trait to denote ops that are permitted in the module initializer (#1379)

This PR adds an `AllowedInModuleInitializer` trait to keep track of ops that are permitted in the module initializer. We have a handful of such ops that are produced by the IValue importer, and so this change avoids maintaining a list of ops in `TorchOps.cpp` that could lead to spurious merge conflicts, and helps us integrate torch-mlir into our downstream compiler better. Please let me know if you'd prefer a better name for the trait itself. Feedback is welcome!
pull/1385/head
Sambhav Jain 2022-09-19 14:56:35 -07:00 committed by GitHub
parent e17fcea94e
commit bb47b36eac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 28 deletions

View File

@ -57,5 +57,6 @@ def ReadOnly : TorchOpTrait<"ReadOnly">;
def IsTrailingUnderscoreInplaceVariant def IsTrailingUnderscoreInplaceVariant
: TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">; : TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">;
def AllowsTypeRefinement : TorchOpTrait<"AllowsTypeRefinement">; def AllowsTypeRefinement : TorchOpTrait<"AllowsTypeRefinement">;
def AllowedInModuleInitializer : TorchOpTrait<"AllowedInModuleInitializer">;
#endif // TORCH_BASE #endif // TORCH_BASE

View File

@ -251,7 +251,8 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializer", [ def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializer", [
IsolatedFromAbove, IsolatedFromAbove,
SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp"> SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp">,
AllowedInModuleInitializer,
]> { ]> {
let summary = "Module initializer for all `torch.global_slot` ops"; let summary = "Module initializer for all `torch.global_slot` ops";
let description = [{ let description = [{
@ -277,7 +278,9 @@ def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializ
def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [ def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
Terminator, Terminator,
HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">]> { HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">,
AllowedInModuleInitializer,
]> {
let summary = "Terminator for torch.global_slot.module_initializer region"; let summary = "Terminator for torch.global_slot.module_initializer region";
let description = [{ let description = [{
Atomically updates the value of all the global slots named in `slotSymNames` Atomically updates the value of all the global slots named in `slotSymNames`
@ -375,8 +378,9 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
NoSideEffect, NoSideEffect,
TypesMatchWith<"contained types correspond to operand types", TypesMatchWith<"contained types correspond to operand types",
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))", "elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))",
"isValidSubtype"> "isValidSubtype">,
]> { AllowedInModuleInitializer,
]> {
let summary = "TorchScript prim::TupleConstruct op"; let summary = "TorchScript prim::TupleConstruct op";
let description = [{ let description = [{
Note: This op does not allow trivial type refinement, because the Note: This op does not allow trivial type refinement, because the
@ -398,7 +402,8 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
NoSideEffect, NoSideEffect,
AllowsTypeRefinement, AllowsTypeRefinement,
]> { AllowedInModuleInitializer,
]> {
let summary = "TorchScript prim::ListConstruct op"; let summary = "TorchScript prim::ListConstruct op";
let arguments = (ins let arguments = (ins
@ -418,7 +423,8 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
AllowsTypeRefinement, AllowsTypeRefinement,
SameVariadicOperandSize, SameVariadicOperandSize,
]> { AllowedInModuleInitializer,
]> {
let summary = "TorchScript prim::DictConstruct op"; let summary = "TorchScript prim::DictConstruct op";
let arguments = (ins let arguments = (ins
@ -650,9 +656,12 @@ def Torch_PrimExitOp : Torch_Op<"prim.Exit", []> {
// Ops corresponding to prim::Constant // Ops corresponding to prim::Constant
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def Torch_ConstantNoneOp : Torch_Op<"constant.none", def Torch_ConstantNoneOp : Torch_Op<"constant.none", [
[ConstantLike, NoSideEffect, ConstantLike,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> { NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AllowedInModuleInitializer,
]> {
let summary = "Get the singleton None value."; let summary = "Get the singleton None value.";
let description = [{ let description = [{
Not to be confused with the `mlir::NoneType`. Be careful to use Not to be confused with the `mlir::NoneType`. Be careful to use
@ -664,9 +673,12 @@ def Torch_ConstantNoneOp : Torch_Op<"constant.none",
let hasFolder = 1; let hasFolder = 1;
} }
def Torch_ConstantStrOp : Torch_Op<"constant.str", def Torch_ConstantStrOp : Torch_Op<"constant.str", [
[ConstantLike, NoSideEffect, ConstantLike,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> { NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AllowedInModuleInitializer,
]> {
let summary = "Materialize a constant str value."; let summary = "Materialize a constant str value.";
let description = [{ let description = [{
Note: Strings in Python (and TorchScript) are immutable. Note: Strings in Python (and TorchScript) are immutable.
@ -697,9 +709,12 @@ def Torch_ConstantDeviceOp : Torch_Op<"constant.device",
let assemblyFormat = "$value attr-dict"; let assemblyFormat = "$value attr-dict";
} }
def Torch_ConstantIntOp : Torch_Op<"constant.int", def Torch_ConstantIntOp : Torch_Op<"constant.int", [
[ConstantLike, NoSideEffect, ConstantLike,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> { NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AllowedInModuleInitializer,
]> {
let summary = "Materialize a constant `int` value."; let summary = "Materialize a constant `int` value.";
let description = [{ let description = [{
Note: TorchScript represents integers as 64-bit signed values, unlike Note: TorchScript represents integers as 64-bit signed values, unlike
@ -716,9 +731,12 @@ def Torch_ConstantIntOp : Torch_Op<"constant.int",
let hasFolder = 1; let hasFolder = 1;
} }
def Torch_ConstantFloatOp : Torch_Op<"constant.float", def Torch_ConstantFloatOp : Torch_Op<"constant.float", [
[ConstantLike, NoSideEffect, ConstantLike,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> { NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AllowedInModuleInitializer,
]> {
let summary = "Materialize a constant `float` value."; let summary = "Materialize a constant `float` value.";
let description = [{ let description = [{
Note: TorchScript represents `float` as 64-bit floating point values. Note: TorchScript represents `float` as 64-bit floating point values.
@ -735,9 +753,12 @@ def Torch_ConstantFloatOp : Torch_Op<"constant.float",
let hasFolder = 1; let hasFolder = 1;
} }
def Torch_ConstantBoolOp : Torch_Op<"constant.bool", def Torch_ConstantBoolOp : Torch_Op<"constant.bool", [
[ConstantLike, NoSideEffect, ConstantLike,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> { NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AllowedInModuleInitializer,
]> {
let summary = "Materialize a constant `bool` value."; let summary = "Materialize a constant `bool` value.";
let description = [{ let description = [{
}]; }];
@ -808,7 +829,8 @@ def Torch_OperatorOp : Torch_Op<"operator", [
} }
def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [ def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
AllowsTypeRefinement AllowsTypeRefinement,
AllowedInModuleInitializer,
]> { ]> {
let summary = "Create a `!torch.LinearParams`"; let summary = "Create a `!torch.LinearParams`";
let arguments = (ins let arguments = (ins
@ -823,7 +845,8 @@ def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
} }
def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [ def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
AllowsTypeRefinement AllowsTypeRefinement,
AllowedInModuleInitializer,
]> { ]> {
let summary = "Create a per-tensor-affine quantized tensor"; let summary = "Create a per-tensor-affine quantized tensor";
let description = [{ let description = [{
@ -854,6 +877,7 @@ def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
def Torch_NonValueTensorLiteralOp : Torch_Op<"tensor.literal", [ def Torch_NonValueTensorLiteralOp : Torch_Op<"tensor.literal", [
DeclareOpInterfaceMethods<InferTypeOpInterface, ["isCompatibleReturnTypes"]>, DeclareOpInterfaceMethods<InferTypeOpInterface, ["isCompatibleReturnTypes"]>,
AllowsTypeRefinement, AllowsTypeRefinement,
AllowedInModuleInitializer,
]> { ]> {
let summary = "Create a value of !torch.tensor type from a literal"; let summary = "Create a value of !torch.tensor type from a literal";
let description = [{ let description = [{

View File

@ -56,6 +56,14 @@ template <typename ConcreteType>
class AllowsTypeRefinement class AllowsTypeRefinement
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {}; : public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
// If a Torch op has this trait, it means that the op is allowed to be used
// in the module initializer. Only a small set of ops are permitted in the
// module initializer. These ops are essentially those which can be produced
// by the IValue importer.
template <typename ConcreteType>
class AllowedInModuleInitializer
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowedInModuleInitializer> {};
} // namespace OpTrait } // namespace OpTrait
} // namespace Torch } // namespace Torch
} // namespace torch } // namespace torch

View File

@ -2203,11 +2203,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
// We only permit a small set of ops in the module initializer. // We only permit a small set of ops in the module initializer.
// These ops are essentially those which can be produced by the IValue // These ops are essentially those which can be produced by the IValue
// importer. // importer.
if (isa<GlobalSlotModuleInitializerOp, InitializeGlobalSlotsOp, if (op->hasTrait<mlir::torch::Torch::OpTrait::AllowedInModuleInitializer>())
PrimListConstructOp, PrimDictConstructOp, PrimTupleConstructOp,
ConstantBoolOp, ConstantStrOp, ConstantIntOp, ConstantFloatOp,
ConstantNoneOp, NonValueTensorLiteralOp, PerTensorAffineCreateOp,
LinearParamsCreateOp>(op))
return WalkResult::advance(); return WalkResult::advance();
op->emitOpError() << "is not allowed in a module initializer"; op->emitOpError() << "is not allowed in a module initializer";
return WalkResult::interrupt(); return WalkResult::interrupt();