diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchBase.td b/include/torch-mlir/Dialect/Torch/IR/TorchBase.td index 0a58dfb03..a5e8767e6 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchBase.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchBase.td @@ -57,5 +57,6 @@ def ReadOnly : TorchOpTrait<"ReadOnly">; def IsTrailingUnderscoreInplaceVariant : TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">; def AllowsTypeRefinement : TorchOpTrait<"AllowsTypeRefinement">; +def AllowedInModuleInitializer : TorchOpTrait<"AllowedInModuleInitializer">; #endif // TORCH_BASE diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index fae78b45a..d46431e6c 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -251,7 +251,8 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [ def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializer", [ IsolatedFromAbove, - SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp"> + SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp">, + AllowedInModuleInitializer, ]> { let summary = "Module initializer for all `torch.global_slot` ops"; let description = [{ @@ -277,7 +278,9 @@ def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializ def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [ Terminator, - HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">]> { + HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">, + AllowedInModuleInitializer, + ]> { let summary = "Terminator for torch.global_slot.module_initializer region"; let description = [{ Atomically updates the value of all the global slots named in `slotSymNames` @@ -375,8 +378,9 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [ NoSideEffect, TypesMatchWith<"contained types correspond to operand types", 
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))", - "isValidSubtype"> - ]> { + "isValidSubtype">, + AllowedInModuleInitializer, + ]> { let summary = "TorchScript prim::TupleConstruct op"; let description = [{ Note: This op does not allow trivial type refinement, because the @@ -398,7 +402,8 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [ NoSideEffect, AllowsTypeRefinement, - ]> { + AllowedInModuleInitializer, + ]> { let summary = "TorchScript prim::ListConstruct op"; let arguments = (ins @@ -418,7 +423,8 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [ AllowsTypeRefinement, SameVariadicOperandSize, - ]> { + AllowedInModuleInitializer, + ]> { let summary = "TorchScript prim::DictConstruct op"; let arguments = (ins @@ -650,9 +656,12 @@ def Torch_PrimExitOp : Torch_Op<"prim.Exit", []> { // Ops corresponding to prim::Constant //===----------------------------------------------------------------------===// -def Torch_ConstantNoneOp : Torch_Op<"constant.none", - [ConstantLike, NoSideEffect, - DeclareOpInterfaceMethods<InferTypeOpInterface>]> { +def Torch_ConstantNoneOp : Torch_Op<"constant.none", [ + ConstantLike, + NoSideEffect, + DeclareOpInterfaceMethods<InferTypeOpInterface>, + AllowedInModuleInitializer, + ]> { let summary = "Get the singleton None value."; let description = [{ Not to be confused with the `mlir::NoneType`. 
Be careful to use @@ -664,9 +673,12 @@ def Torch_ConstantNoneOp : Torch_Op<"constant.none", let hasFolder = 1; } -def Torch_ConstantStrOp : Torch_Op<"constant.str", - [ConstantLike, NoSideEffect, - DeclareOpInterfaceMethods<InferTypeOpInterface>]> { +def Torch_ConstantStrOp : Torch_Op<"constant.str", [ + ConstantLike, + NoSideEffect, + DeclareOpInterfaceMethods<InferTypeOpInterface>, + AllowedInModuleInitializer, + ]> { let summary = "Materialize a constant str value."; let description = [{ Note: Strings in Python (and TorchScript) are immutable. @@ -697,9 +709,12 @@ def Torch_ConstantDeviceOp : Torch_Op<"constant.device", let assemblyFormat = "$value attr-dict"; } -def Torch_ConstantIntOp : Torch_Op<"constant.int", - [ConstantLike, NoSideEffect, - DeclareOpInterfaceMethods<InferTypeOpInterface>]> { +def Torch_ConstantIntOp : Torch_Op<"constant.int", [ + ConstantLike, + NoSideEffect, + DeclareOpInterfaceMethods<InferTypeOpInterface>, + AllowedInModuleInitializer, + ]> { let summary = "Materialize a constant `int` value."; let description = [{ Note: TorchScript represents integers as 64-bit signed values, unlike @@ -716,9 +731,12 @@ def Torch_ConstantIntOp : Torch_Op<"constant.int", let hasFolder = 1; } -def Torch_ConstantFloatOp : Torch_Op<"constant.float", - [ConstantLike, NoSideEffect, - DeclareOpInterfaceMethods<InferTypeOpInterface>]> { +def Torch_ConstantFloatOp : Torch_Op<"constant.float", [ + ConstantLike, + NoSideEffect, + DeclareOpInterfaceMethods<InferTypeOpInterface>, + AllowedInModuleInitializer, + ]> { let summary = "Materialize a constant `float` value."; let description = [{ Note: TorchScript represents `float` as 64-bit floating point values. 
@@ -735,9 +753,12 @@ def Torch_ConstantFloatOp : Torch_Op<"constant.float", let hasFolder = 1; } -def Torch_ConstantBoolOp : Torch_Op<"constant.bool", - [ConstantLike, NoSideEffect, - DeclareOpInterfaceMethods<InferTypeOpInterface>]> { +def Torch_ConstantBoolOp : Torch_Op<"constant.bool", [ + ConstantLike, + NoSideEffect, + DeclareOpInterfaceMethods<InferTypeOpInterface>, + AllowedInModuleInitializer, + ]> { let summary = "Materialize a constant `bool` value."; let description = [{ }]; @@ -808,7 +829,8 @@ def Torch_OperatorOp : Torch_Op<"operator", [ } def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [ - AllowsTypeRefinement + AllowsTypeRefinement, + AllowedInModuleInitializer, ]> { let summary = "Create a `!torch.LinearParams`"; let arguments = (ins @@ -823,7 +845,8 @@ def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [ } def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [ - AllowsTypeRefinement + AllowsTypeRefinement, + AllowedInModuleInitializer, ]> { let summary = "Create a per-tensor-affine quantized tensor"; let description = [{ @@ -854,6 +877,7 @@ def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [ def Torch_NonValueTensorLiteralOp : Torch_Op<"tensor.literal", [ DeclareOpInterfaceMethods<InferTypeOpInterface, ["isCompatibleReturnTypes"]>, AllowsTypeRefinement, + AllowedInModuleInitializer, ]> { let summary = "Create a value of !torch.tensor type from a literal"; let description = [{ diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h b/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h index 23b4c2ffe..20f1bc109 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h @@ -56,6 +56,14 @@ template <typename ConcreteType> class AllowsTypeRefinement : public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {}; +// If a Torch op has this trait, it means that the op is allowed to be used +// in the module initializer. Only a small set of ops are permitted in the +// module initializer. 
These ops are essentially those which can be produced // by the IValue importer. +template <typename ConcreteType> +class AllowedInModuleInitializer + : public ::mlir::OpTrait::TraitBase<ConcreteType, AllowedInModuleInitializer> {}; + } // namespace OpTrait } // namespace Torch } // namespace torch diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 3d088aa9e..2a54c55ec 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2203,11 +2203,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() { // We only permit a small set of ops in the module initializer. // These ops are essentially those which can be produced by the IValue // importer. - if (isa<GlobalSlotModuleInitializerOp, InitializeGlobalSlotsOp, PrimTupleConstructOp, PrimListConstructOp, PrimDictConstructOp, ConstantNoneOp, ConstantStrOp, ConstantIntOp, ConstantFloatOp, ConstantBoolOp, LinearParamsCreateOp, PerTensorAffineCreateOp, NonValueTensorLiteralOp>(op)) + if (op->hasTrait<Torch::OpTrait::AllowedInModuleInitializer>()) return WalkResult::advance(); op->emitOpError() << "is not allowed in a module initializer"; return WalkResult::interrupt();