Fix build failure

pull/547/head snapshot-20220128.234
Yi Zhang 2022-01-28 11:31:28 -05:00
parent 52ed3313b4
commit e1b3e5bc92
3 changed files with 8 additions and 4 deletions

View File

@ -165,13 +165,14 @@ def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics HasValueSemantics
]> { ]> {
let summary = "Generated op for `prim::RaiseException : (str) -> ()`"; let summary = "Generated op for `prim::RaiseException : (str, str?) -> ()`";
let arguments = (ins let arguments = (ins
Torch_StringType:$msg Torch_StringType:$msg,
TorchOptionalStringType:$cls
); );
let results = (outs let results = (outs
); );
let assemblyFormat = "$msg attr-dict `:` qualified(type($msg))"; let assemblyFormat = "$msg `,` $cls attr-dict `:` qualified(type($msg)) `,` qualified(type($cls))";
} }
def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [

View File

@ -380,6 +380,8 @@ def AnyTorchOptionalTensorType :
def TorchOptionalIntType: OptionalOf<Torch_IntType, "Optional torch int type">; def TorchOptionalIntType: OptionalOf<Torch_IntType, "Optional torch int type">;
def TorchOptionalBoolType: def TorchOptionalBoolType:
OptionalOf<Torch_BoolType, "Optional torch bool type">; OptionalOf<Torch_BoolType, "Optional torch bool type">;
def TorchOptionalStringType:
OptionalOf<Torch_StringType, "Optional torch Str type">;
def TorchOptionalDeviceType: def TorchOptionalDeviceType:
OptionalOf<Torch_DeviceType, "Optional torch device type">; OptionalOf<Torch_DeviceType, "Optional torch device type">;

View File

@ -238,6 +238,7 @@ TORCH_TYPE_TO_ODS_TYPE = {
"Device": "Torch_DeviceType", "Device": "Torch_DeviceType",
"Device?": "TorchOptionalDeviceType", "Device?": "TorchOptionalDeviceType",
"str": "Torch_StringType", "str": "Torch_StringType",
"str?": "TorchOptionalStringType",
"str[]": "TorchStringListType", "str[]": "TorchStringListType",
"Dict": "Torch_DictType", "Dict": "Torch_DictType",
"__torch__.torch.classes.quantized.LinearPackedParamsBase": "Torch_LinearParamsType", "__torch__.torch.classes.quantized.LinearPackedParamsBase": "Torch_LinearParamsType",
@ -413,7 +414,7 @@ def emit_prim_ops(torch_ir_dir: str, registry: Registry):
emit("prim::min.int : (int, int) -> (int)") emit("prim::min.int : (int, int) -> (int)")
emit("prim::max.self_int : (int[]) -> (int)") emit("prim::max.self_int : (int[]) -> (int)")
emit("prim::max.int : (int, int) -> (int)") emit("prim::max.int : (int, int) -> (int)")
emit("prim::RaiseException : (str) -> ()") emit("prim::RaiseException : (str, str?) -> ()")
emit("prim::Uninitialized : () -> (Any)", traits=["NoSideEffect"]) emit("prim::Uninitialized : () -> (Any)", traits=["NoSideEffect"])
emit("prim::unchecked_cast : (t) -> (t)", emit("prim::unchecked_cast : (t) -> (t)",
traits=["DeclareOpInterfaceMethods<CastOpInterface>"]) traits=["DeclareOpInterfaceMethods<CastOpInterface>"])