Mirror of https://github.com/llvm/torch-mlir
Add a `AllowedInModuleInitializer` trait to denote ops that are permitted in the module initializer (#1379)
This PR adds an `AllowedInModuleInitializer` trait to keep track of ops that are permitted in the module initializer. We have a handful of such ops that are produced by the IValue importer, so this change avoids maintaining a list of ops in `TorchOps.cpp` that could lead to spurious merge conflicts, and helps us integrate torch-mlir into our downstream compiler better. Please let me know if you'd prefer a better name for the trait itself. Feedback is welcome!

pull/1385/head
parent
e17fcea94e
commit
bb47b36eac
|
@ -57,5 +57,6 @@ def ReadOnly : TorchOpTrait<"ReadOnly">;
|
|||
def IsTrailingUnderscoreInplaceVariant
|
||||
: TorchOpTrait<"IsTrailingUnderscoreInplaceVariant">;
|
||||
def AllowsTypeRefinement : TorchOpTrait<"AllowsTypeRefinement">;
|
||||
def AllowedInModuleInitializer : TorchOpTrait<"AllowedInModuleInitializer">;
|
||||
|
||||
#endif // TORCH_BASE
|
||||
|
|
|
@ -251,7 +251,8 @@ def Torch_GlobalSlotOp : Torch_Op<"global_slot", [
|
|||
|
||||
def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializer", [
|
||||
IsolatedFromAbove,
|
||||
SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp">
|
||||
SingleBlockImplicitTerminator<"::mlir::torch::Torch::InitializeGlobalSlotsOp">,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "Module initializer for all `torch.global_slot` ops";
|
||||
let description = [{
|
||||
|
@ -277,7 +278,9 @@ def Torch_GlobalSlotModuleInitializerOp : Torch_Op<"global_slot.module_initializ
|
|||
|
||||
def Torch_InitializeGlobalSlotsOp : Torch_Op<"initialize.global_slots", [
|
||||
Terminator,
|
||||
HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">]> {
|
||||
HasParent<"::mlir::torch::Torch::GlobalSlotModuleInitializerOp">,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "Terminator for torch.global_slot.module_initializer region";
|
||||
let description = [{
|
||||
Atomically updates the value of all the global slots named in `slotSymNames`
|
||||
|
@ -375,8 +378,9 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
|||
NoSideEffect,
|
||||
TypesMatchWith<"contained types correspond to operand types",
|
||||
"elements", "result", "Torch::TupleType::get($_ctxt, llvm::to_vector<6>($_self))",
|
||||
"isValidSubtype">
|
||||
]> {
|
||||
"isValidSubtype">,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "TorchScript prim::TupleConstruct op";
|
||||
let description = [{
|
||||
Note: This op does not allow trivial type refinement, because the
|
||||
|
@ -398,7 +402,8 @@ def Torch_PrimTupleConstructOp: Torch_Op<"prim.TupleConstruct", [
|
|||
def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
|
||||
NoSideEffect,
|
||||
AllowsTypeRefinement,
|
||||
]> {
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "TorchScript prim::ListConstruct op";
|
||||
|
||||
let arguments = (ins
|
||||
|
@ -418,7 +423,8 @@ def Torch_PrimListConstructOp: Torch_Op<"prim.ListConstruct", [
|
|||
def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
|
||||
AllowsTypeRefinement,
|
||||
SameVariadicOperandSize,
|
||||
]> {
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "TorchScript prim::DictConstruct op";
|
||||
|
||||
let arguments = (ins
|
||||
|
@ -650,9 +656,12 @@ def Torch_PrimExitOp : Torch_Op<"prim.Exit", []> {
|
|||
// Ops corresponding to prim::Constant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Torch_ConstantNoneOp : Torch_Op<"constant.none",
|
||||
[ConstantLike, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
||||
def Torch_ConstantNoneOp : Torch_Op<"constant.none", [
|
||||
ConstantLike,
|
||||
NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "Get the singleton None value.";
|
||||
let description = [{
|
||||
Not to be confused with the `mlir::NoneType`. Be careful to use
|
||||
|
@ -664,9 +673,12 @@ def Torch_ConstantNoneOp : Torch_Op<"constant.none",
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_ConstantStrOp : Torch_Op<"constant.str",
|
||||
[ConstantLike, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
||||
def Torch_ConstantStrOp : Torch_Op<"constant.str", [
|
||||
ConstantLike,
|
||||
NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "Materialize a constant str value.";
|
||||
let description = [{
|
||||
Note: Strings in Python (and TorchScript) are immutable.
|
||||
|
@ -697,9 +709,12 @@ def Torch_ConstantDeviceOp : Torch_Op<"constant.device",
|
|||
let assemblyFormat = "$value attr-dict";
|
||||
}
|
||||
|
||||
def Torch_ConstantIntOp : Torch_Op<"constant.int",
|
||||
[ConstantLike, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
||||
def Torch_ConstantIntOp : Torch_Op<"constant.int", [
|
||||
ConstantLike,
|
||||
NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "Materialize a constant `int` value.";
|
||||
let description = [{
|
||||
Note: TorchScript represents integers as 64-bit signed values, unlike
|
||||
|
@ -716,9 +731,12 @@ def Torch_ConstantIntOp : Torch_Op<"constant.int",
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_ConstantFloatOp : Torch_Op<"constant.float",
|
||||
[ConstantLike, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
||||
def Torch_ConstantFloatOp : Torch_Op<"constant.float", [
|
||||
ConstantLike,
|
||||
NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "Materialize a constant `float` value.";
|
||||
let description = [{
|
||||
Note: TorchScript represents `float` as 64-bit floating point values.
|
||||
|
@ -735,9 +753,12 @@ def Torch_ConstantFloatOp : Torch_Op<"constant.float",
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_ConstantBoolOp : Torch_Op<"constant.bool",
|
||||
[ConstantLike, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
|
||||
def Torch_ConstantBoolOp : Torch_Op<"constant.bool", [
|
||||
ConstantLike,
|
||||
NoSideEffect,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "Materialize a constant `bool` value.";
|
||||
let description = [{
|
||||
}];
|
||||
|
@ -808,7 +829,8 @@ def Torch_OperatorOp : Torch_Op<"operator", [
|
|||
}
|
||||
|
||||
def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
|
||||
AllowsTypeRefinement
|
||||
AllowsTypeRefinement,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "Create a `!torch.LinearParams`";
|
||||
let arguments = (ins
|
||||
|
@ -823,7 +845,8 @@ def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
|
|||
}
|
||||
|
||||
def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
|
||||
AllowsTypeRefinement
|
||||
AllowsTypeRefinement,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "Create a per-tensor-affine quantized tensor";
|
||||
let description = [{
|
||||
|
@ -854,6 +877,7 @@ def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
|
|||
def Torch_NonValueTensorLiteralOp : Torch_Op<"tensor.literal", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface, ["isCompatibleReturnTypes"]>,
|
||||
AllowsTypeRefinement,
|
||||
AllowedInModuleInitializer,
|
||||
]> {
|
||||
let summary = "Create a value of !torch.tensor type from a literal";
|
||||
let description = [{
|
||||
|
|
|
@ -56,6 +56,14 @@ template <typename ConcreteType>
|
|||
class AllowsTypeRefinement
|
||||
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowsTypeRefinement> {};
|
||||
|
||||
// If a Torch op has this trait, it means that the op is allowed to be used
|
||||
// in the module initializer. Only a small set of ops are permitted in the
|
||||
// module initializer. These ops are essentially those which can be produced
|
||||
// by the IValue importer.
|
||||
template <typename ConcreteType>
|
||||
class AllowedInModuleInitializer
|
||||
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowedInModuleInitializer> {};
|
||||
|
||||
} // namespace OpTrait
|
||||
} // namespace Torch
|
||||
} // namespace torch
|
||||
|
|
|
@ -2203,11 +2203,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
|
|||
// We only permit a small set of ops in the module initializer.
|
||||
// These ops are essentially those which can be produced by the IValue
|
||||
// importer.
|
||||
if (isa<GlobalSlotModuleInitializerOp, InitializeGlobalSlotsOp,
|
||||
PrimListConstructOp, PrimDictConstructOp, PrimTupleConstructOp,
|
||||
ConstantBoolOp, ConstantStrOp, ConstantIntOp, ConstantFloatOp,
|
||||
ConstantNoneOp, NonValueTensorLiteralOp, PerTensorAffineCreateOp,
|
||||
LinearParamsCreateOp>(op))
|
||||
if (op->hasTrait<mlir::torch::Torch::OpTrait::AllowedInModuleInitializer>())
|
||||
return WalkResult::advance();
|
||||
op->emitOpError() << "is not allowed in a module initializer";
|
||||
return WalkResult::interrupt();
|
||||
|
|
Loading…
Reference in New Issue