mirror of https://github.com/llvm/torch-mlir
Add a `AllowedInModuleInitializer` trait to denote ops that are permitted in the module initializer (#1379)
This PR adds an `AllowedInModuleInitializer` trait to keep track of ops that are permitted in the module initializer. We have a handful of such ops that are produced by the IValue importer, and so this change avoids maintaining a list of ops in `TorchOps.cpp` that could lead to spurious merge conflicts, and help us integrate torch-mlir in our downstream compiler better. Please let me know if you'd prefer a better name for the trait itself. Feedback is welcome!pull/1385/head
parent
e17fcea94e
commit
bb47b36eac
|
@ -57,5 +57,6 @@ def ReadOnly : TorchOpTrait<"ReadOnly">;
|
||||||
def IsTrailingUnderscoreInplaceVariant
|
def IsTrailingUnderscoreInplaceVariant
|
||||||
: TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">;
|
: TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">;
|
||||||
def AllowsTypeRefinement : TorchOpTrait<"AllowsTypeRefinement">;
|
def AllowsTypeRefinement : TorchOpTrait<"AllowsTypeRefinement">;
|
||||||
|
def AllowedInModuleInitializer : TorchOpTrait<"AllowedInModuleInitializer">;
|
||||||
|
|
||||||
#endif // TORCH_BASE
|
#endif // TORCH_BASE
|
||||||
|
|
|
@ -251,7 +251,8 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
||||||
|
|
||||||
def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializer", [
|
def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializer", [
|
||||||
IsolatedFromAbove,
|
IsolatedFromAbove,
|
||||||
SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp">
|
SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp">,
|
||||||
|
AllowedInModuleInitializer,
|
||||||
]> {
|
]> {
|
||||||
let summary = "Module initializer for all `torch.global_slot` ops";
|
let summary = "Module initializer for all `torch.global_slot` ops";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -277,7 +278,9 @@ def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializ
|
||||||
|
|
||||||
def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
|
def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
|
||||||
Terminator,
|
Terminator,
|
||||||
HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">]> {
|
HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">,
|
||||||
|
AllowedInModuleInitializer,
|
||||||
|
]> {
|
||||||
let summary = "Terminator for torch.global_slot.module_initializer region";
|
let summary = "Terminator for torch.global_slot.module_initializer region";
|
||||||
let description = [{
|
let description = [{
|
||||||
Atomically updates the value of all the global slots named in `slotSymNames`
|
Atomically updates the value of all the global slots named in `slotSymNames`
|
||||||
|
@ -375,8 +378,9 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
||||||
NoSideEffect,
|
NoSideEffect,
|
||||||
TypesMatchWith<"contained types correspond to operand types",
|
TypesMatchWith<"contained types correspond to operand types",
|
||||||
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))",
|
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))",
|
||||||
"isValidSubtype">
|
"isValidSubtype">,
|
||||||
]> {
|
AllowedInModuleInitializer,
|
||||||
|
]> {
|
||||||
let summary = "TorchScript prim::TupleConstruct op";
|
let summary = "TorchScript prim::TupleConstruct op";
|
||||||
let description = [{
|
let description = [{
|
||||||
Note: This op does not allow trivial type refinement, because the
|
Note: This op does not allow trivial type refinement, because the
|
||||||
|
@ -398,7 +402,8 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
||||||
def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
|
def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
|
||||||
NoSideEffect,
|
NoSideEffect,
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
]> {
|
AllowedInModuleInitializer,
|
||||||
|
]> {
|
||||||
let summary = "TorchScript prim::ListConstruct op";
|
let summary = "TorchScript prim::ListConstruct op";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
@ -418,7 +423,8 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
|
||||||
def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
|
def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
SameVariadicOperandSize,
|
SameVariadicOperandSize,
|
||||||
]> {
|
AllowedInModuleInitializer,
|
||||||
|
]> {
|
||||||
let summary = "TorchScript prim::DictConstruct op";
|
let summary = "TorchScript prim::DictConstruct op";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
@ -650,9 +656,12 @@ def Torch_PrimExitOp : Torch_Op<"prim.Exit", []> {
|
||||||
// Ops corresponding to prim::Constant
|
// Ops corresponding to prim::Constant
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def Torch_ConstantNoneOp : Torch_Op<"constant.none",
|
def Torch_ConstantNoneOp : Torch_Op<"constant.none", [
|
||||||
[ConstantLike, NoSideEffect,
|
ConstantLike,
|
||||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||||
|
AllowedInModuleInitializer,
|
||||||
|
]> {
|
||||||
let summary = "Get the singleton None value.";
|
let summary = "Get the singleton None value.";
|
||||||
let description = [{
|
let description = [{
|
||||||
Not to be confused with the `mlir::NoneType`. Be careful to use
|
Not to be confused with the `mlir::NoneType`. Be careful to use
|
||||||
|
@ -664,9 +673,12 @@ def Torch_ConstantNoneOp : Torch_Op<"constant.none",
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_ConstantStrOp : Torch_Op<"constant.str",
|
def Torch_ConstantStrOp : Torch_Op<"constant.str", [
|
||||||
[ConstantLike, NoSideEffect,
|
ConstantLike,
|
||||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||||
|
AllowedInModuleInitializer,
|
||||||
|
]> {
|
||||||
let summary = "Materialize a constant str value.";
|
let summary = "Materialize a constant str value.";
|
||||||
let description = [{
|
let description = [{
|
||||||
Note: Strings in Python (and TorchScript) are immutable.
|
Note: Strings in Python (and TorchScript) are immutable.
|
||||||
|
@ -697,9 +709,12 @@ def Torch_ConstantDeviceOp : Torch_Op<"constant.device",
|
||||||
let assemblyFormat = "$value attr-dict";
|
let assemblyFormat = "$value attr-dict";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_ConstantIntOp : Torch_Op<"constant.int",
|
def Torch_ConstantIntOp : Torch_Op<"constant.int", [
|
||||||
[ConstantLike, NoSideEffect,
|
ConstantLike,
|
||||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||||
|
AllowedInModuleInitializer,
|
||||||
|
]> {
|
||||||
let summary = "Materialize a constant `int` value.";
|
let summary = "Materialize a constant `int` value.";
|
||||||
let description = [{
|
let description = [{
|
||||||
Note: TorchScript represents integers as 64-bit signed values, unlike
|
Note: TorchScript represents integers as 64-bit signed values, unlike
|
||||||
|
@ -716,9 +731,12 @@ def Torch_ConstantIntOp : Torch_Op<"constant.int",
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_ConstantFloatOp : Torch_Op<"constant.float",
|
def Torch_ConstantFloatOp : Torch_Op<"constant.float", [
|
||||||
[ConstantLike, NoSideEffect,
|
ConstantLike,
|
||||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||||
|
AllowedInModuleInitializer,
|
||||||
|
]> {
|
||||||
let summary = "Materialize a constant `float` value.";
|
let summary = "Materialize a constant `float` value.";
|
||||||
let description = [{
|
let description = [{
|
||||||
Note: TorchScript represents `float` as 64-bit floating point values.
|
Note: TorchScript represents `float` as 64-bit floating point values.
|
||||||
|
@ -735,9 +753,12 @@ def Torch_ConstantFloatOp : Torch_Op<"constant.float",
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_ConstantBoolOp : Torch_Op<"constant.bool",
|
def Torch_ConstantBoolOp : Torch_Op<"constant.bool", [
|
||||||
[ConstantLike, NoSideEffect,
|
ConstantLike,
|
||||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||||
|
AllowedInModuleInitializer,
|
||||||
|
]> {
|
||||||
let summary = "Materialize a constant `bool` value.";
|
let summary = "Materialize a constant `bool` value.";
|
||||||
let description = [{
|
let description = [{
|
||||||
}];
|
}];
|
||||||
|
@ -808,7 +829,8 @@ def Torch_OperatorOp : Torch_Op<"operator", [
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
|
def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
|
||||||
AllowsTypeRefinement
|
AllowsTypeRefinement,
|
||||||
|
AllowedInModuleInitializer,
|
||||||
]> {
|
]> {
|
||||||
let summary = "Create a `!torch.LinearParams`";
|
let summary = "Create a `!torch.LinearParams`";
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
@ -823,7 +845,8 @@ def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
|
def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
|
||||||
AllowsTypeRefinement
|
AllowsTypeRefinement,
|
||||||
|
AllowedInModuleInitializer,
|
||||||
]> {
|
]> {
|
||||||
let summary = "Create a per-tensor-affine quantized tensor";
|
let summary = "Create a per-tensor-affine quantized tensor";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -854,6 +877,7 @@ def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
|
||||||
def Torch_NonValueTensorLiteralOp : Torch_Op<"tensor.literal", [
|
def Torch_NonValueTensorLiteralOp : Torch_Op<"tensor.literal", [
|
||||||
DeclareOpInterfaceMethods<InferTypeOpInterface, ["isCompatibleReturnTypes"]>,
|
DeclareOpInterfaceMethods<InferTypeOpInterface, ["isCompatibleReturnTypes"]>,
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
|
AllowedInModuleInitializer,
|
||||||
]> {
|
]> {
|
||||||
let summary = "Create a value of !torch.tensor type from a literal";
|
let summary = "Create a value of !torch.tensor type from a literal";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -56,6 +56,14 @@ template <typename ConcreteType>
|
||||||
class AllowsTypeRefinement
|
class AllowsTypeRefinement
|
||||||
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
|
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
|
||||||
|
|
||||||
|
// If a Torch op has this trait, it means that the op is allowed to be used
|
||||||
|
// in the module initializer. Only a small set of ops are permitted in the
|
||||||
|
// module initializer. These ops are essentially those which can be produced
|
||||||
|
// by the IValue importer.
|
||||||
|
template <typename ConcreteType>
|
||||||
|
class AllowedInModuleInitializer
|
||||||
|
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowedInModuleInitializer> {};
|
||||||
|
|
||||||
} // namespace OpTrait
|
} // namespace OpTrait
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
@ -2203,11 +2203,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
|
||||||
// We only permit a small set of ops in the module initializer.
|
// We only permit a small set of ops in the module initializer.
|
||||||
// These ops are essentially those which can be produced by the IValue
|
// These ops are essentially those which can be produced by the IValue
|
||||||
// importer.
|
// importer.
|
||||||
if (isa<GlobalSlotModuleInitializerOp, InitializeGlobalSlotsOp,
|
if (op->hasTrait<mlir::torch::Torch::OpTrait::AllowedInModuleInitializer>())
|
||||||
PrimListConstructOp, PrimDictConstructOp, PrimTupleConstructOp,
|
|
||||||
ConstantBoolOp, ConstantStrOp, ConstantIntOp, ConstantFloatOp,
|
|
||||||
ConstantNoneOp, NonValueTensorLiteralOp, PerTensorAffineCreateOp,
|
|
||||||
LinearParamsCreateOp>(op))
|
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
op->emitOpError() << "is not allowed in a module initializer";
|
op->emitOpError() << "is not allowed in a module initializer";
|
||||||
return WalkResult::interrupt();
|
return WalkResult::interrupt();
|
||||||
|
|
Loading…
Reference in New Issue