Augmented calls to yaml.load to use the safe loader.

pull/3817/head
TimAtGoogle 2024-10-24 15:47:30 -05:00
parent 8b0bf2e293
commit 6f34294dc9
1 changed files with 9 additions and 3 deletions

View File

@ -30,7 +30,13 @@ if not TORCH_INCLUDE_DIR.is_dir():
TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve() TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve()
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent
# Safely load fast C Yaml loader/dumper if they are available
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader #type:ignore[assignment, misc]
dimsa3-reGdoj-ciqbac
def reindent(text, prefix=""): def reindent(text, prefix=""):
return indent(dedent(text), prefix) return indent(dedent(text), prefix)
@ -175,7 +181,7 @@ class GenTorchMlirLTC:
) )
ts_native_yaml = None ts_native_yaml = None
if ts_native_yaml_path.exists(): if ts_native_yaml_path.exists():
ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), yaml.CLoader) ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), Loader)
else: else:
logging.warning( logging.warning(
f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}" f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}"
@ -208,7 +214,7 @@ class GenTorchMlirLTC:
) )
with self.config_path.open() as f: with self.config_path.open() as f:
config = yaml.load(f, yaml.CLoader) config = yaml.load(f, Loader)
# List of unsupported ops in LTC autogen because of some error # List of unsupported ops in LTC autogen because of some error
blacklist = set(config.get("blacklist", [])) blacklist = set(config.get("blacklist", []))