Skip to content

Commit

Permalink
fixes black hook issue
Browse files Browse the repository at this point in the history
minor commit
  • Loading branch information
abhi-glitchhg committed Apr 13, 2022
1 parent 1bed82f commit efd11ed
Showing 1 changed file with 11 additions and 233 deletions.
244 changes: 11 additions & 233 deletions vformer/config/lazy.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,10 @@
import builtins
import importlib
import inspect
import logging
import os
from collections import abc
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import is_dataclass
from typing import List, Tuple, Union

import cloudpickle
import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf

from .config_utils import (
_cast_to_config,
_convert_target_to_string,
_random_package_name,
_validate_py_syntax,
_visit_dict_config,
)
import os
from typing import Union, Tuple
from omegaconf import DictConfig, OmegaConf, ListConfig

_CFG_PACKAGE_NAME = "vformer.cfg_loader"
from .config_utils import _convert_target_to_string

# copied from detectron 2

Expand Down Expand Up @@ -59,64 +42,6 @@ def __call__(self, **kwargs):
return DictConfig(content=kwargs, flags={"allow_objects": True})


@contextmanager
def _patch_import():
"""
Enhance relative import statements in config files, so that they:
1. locate files purely based on relative location, regardless of packages.
e.g. you can import file without having __init__
2. do not cache modules globally; modifications of module states has no side effect
3. support other storage system through PathManager
4. imported dict are turned into omegaconf.DictConfig automatically
"""
old_import = builtins.__import__

def find_relative_file(original_file, relative_import_path, level):
cur_file = os.path.dirname(original_file)
for _ in range(level - 1):
cur_file = os.path.dirname(cur_file)
cur_name = relative_import_path.lstrip(".")
for part in cur_name.split("."):
cur_file = os.path.join(cur_file, part)
# NOTE: directory import is not handled. Because then it's unclear
# if such import should produce python module or DictConfig. This can
# be discussed further if needed.
if not cur_file.endswith(".py"):
cur_file += ".py"
if not os.path.isfile(cur_file):
raise ImportError(
f"Cannot import name {relative_import_path} from "
f"{original_file}: {cur_file} has to exist."
)
return cur_file

def new_import(name, globals=None, locals=None, fromlist=(), level=0):
if (
# Only deal with relative imports inside config files
level != 0
and globals is not None
and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
):
cur_file = find_relative_file(globals["__file__"], name, level)
_validate_py_syntax(cur_file)
spec = importlib.machinery.ModuleSpec(
_random_package_name(cur_file), None, origin=cur_file
)
module = importlib.util.module_from_spec(spec)
module.__file__ = cur_file
with open(cur_file) as f:
content = f.read()
exec(compile(content, cur_file, "exec"), module.__dict__)
for name in fromlist: # turn imported dict into DictConfig automatically
val = _cast_to_config(module.__dict__[name])
module.__dict__[name] = val
return module
return old_import(name, globals, locals, fromlist=fromlist, level=level)

builtins.__import__ = new_import
yield new_import
builtins.__import__ = old_import


class LazyConfig:
"""
Expand Down Expand Up @@ -151,9 +76,7 @@ def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
has_keys = keys is not None
filename = filename.replace("/./", "/") # redundant
if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
raise ValueError(
f"Config file {filename} is not supported, supported file types are : [`.py`, `.yaml`]."
)
raise ValueError(f"Config file {filename} has to be a python or yaml file.")
if filename.endswith(".py"):
_validate_py_syntax(filename)

Expand All @@ -163,23 +86,19 @@ def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
"__file__": filename,
"__package__": _random_package_name(filename),
}
with open(filename) as f:
with PathManager.open(filename) as f:
content = f.read()
# Compile first with filename to:
# 1. make filename appears in stacktrace
# 2. make load_rel able to find its parent's (possibly remote) location
exec(compile(content, filename, "exec"), module_namespace)

ret = module_namespace
elif filename.endswith(".yaml"):

with open(filename) as f:
else:
with PathManager.open(filename) as f:
obj = yaml.unsafe_load(f)
ret = OmegaConf.create(obj, flags={"allow_objects": True})

else:
raise NotImplementedError("Only python and yaml files supported for now. ")

if has_keys:
if isinstance(keys, str):
return _cast_to_config(ret[keys])
Expand Down Expand Up @@ -230,10 +149,8 @@ def _replace_type_by_name(x):
save_pkl = False
try:
dict = OmegaConf.to_container(cfg, resolve=False)
dumped = yaml.dump(
dict, default_flow_style=None, allow_unicode=True, width=9999
)
with open(filename, "w") as f:
dumped = yaml.dump(dict, default_flow_style=None, allow_unicode=True, width=9999)
with PathManager.open(filename, "w") as f:
f.write(dumped)

try:
Expand All @@ -252,150 +169,11 @@ def _replace_type_by_name(x):
new_filename = filename + ".pkl"
try:
# retry by pickle
with open(new_filename, "wb") as f:
with PathManager.open(new_filename, "wb") as f:
cloudpickle.dump(cfg, f)
logger.warning(f"Config is saved using cloudpickle at {new_filename}.")
except Exception:
pass

@staticmethod
def apply_overrides(cfg, overrides: List[str]):
"""
In-place override contents of cfg.
Args:
cfg: an omegaconf config object
overrides: list of strings in the format of "a=b" to override configs.
See https://hydra.cc/docs/next/advanced/override_grammar/basic/
for syntax.
Returns:
the cfg object
"""

def safe_update(cfg, key, value):
parts = key.split(".")
for idx in range(1, len(parts)):
prefix = ".".join(parts[:idx])
v = OmegaConf.select(cfg, prefix, default=None)
if v is None:
break
if not OmegaConf.is_config(v):
raise KeyError(
f"Trying to update key {key}, but {prefix} "
f"is not a config, but has type {type(v)}."
)
OmegaConf.update(cfg, key, value, merge=True)

from hydra.core.override_parser.overrides_parser import OverridesParser

parser = OverridesParser.create()
overrides = parser.parse_overrides(overrides)
for o in overrides:
key = o.key_or_group
value = o.value()
if o.is_delete():
# TODO support this
raise NotImplementedError("deletion is not yet a supported override")
safe_update(cfg, key, value)
return cfg

@staticmethod
def to_py(cfg, prefix: str = "cfg."):
"""
Try to convert a config object into Python-like psuedo code.
Note that perfect conversion is not always possible. So the returned
results are mainly meant to be human-readable, and not meant to be executed.
Args:
cfg: an omegaconf config object
prefix: root name for the resulting code (default: "cfg.")
Returns:
str of formatted Python code
"""
import black

cfg = OmegaConf.to_container(cfg, resolve=True)

def _to_str(obj, prefix=None, inside_call=False):
if prefix is None:
prefix = []
if isinstance(obj, abc.Mapping) and "_target_" in obj:
# Dict representing a function call
target = _convert_target_to_string(obj.pop("_target_"))
args = []
for k, v in sorted(obj.items()):
args.append(f"{k}={_to_str(v, inside_call=True)}")
args = ", ".join(args)
call = f"{target}({args})"
return "".join(prefix) + call
elif isinstance(obj, abc.Mapping) and not inside_call:
# Dict that is not inside a call is a list of top-level config objects that we
# render as one object per line with dot separated prefixes
key_list = []
for k, v in sorted(obj.items()):
if isinstance(v, abc.Mapping) and "_target_" not in v:
key_list.append(_to_str(v, prefix=prefix + [k + "."]))
else:
key = "".join(prefix) + k
key_list.append(f"{key}={_to_str(v)}")
return "\n".join(key_list)
elif isinstance(obj, abc.Mapping):
# Dict that is inside a call is rendered as a regular dict
return (
"{"
+ ",".join(
f"{repr(k)}: {_to_str(v, inside_call=inside_call)}"
for k, v in sorted(obj.items())
)
+ "}"
)
elif isinstance(obj, list):
return (
"["
+ ",".join(_to_str(x, inside_call=inside_call) for x in obj)
+ "]"
)
else:
return repr(obj)

py_str = _to_str(cfg, prefix=[prefix])
try:
return black.format_str(py_str, mode=black.Mode())
except black.InvalidInput:
return py_str


def get_config_file(config_path):
"""
Returns path to a builtin config file.
Args:
config_path (str): config file name relative to detectron2's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
Returns:
str: the real path to the config file.
"""
cfg_file = open(os.path.join("vformer", "configs", config_path))
if not os.path.exists(cfg_file):
raise RuntimeError("{} not available in configs!".format(config_path))
return cfg_file


def get_config(config_path, trained: bool = False):
"""
Returns a config object for a model in model zoo.
Args:
config_path (str): config file name relative to detectron2's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
trained (bool): If True, will set ``MODEL.WEIGHTS`` to trained model zoo weights.
If False, the checkpoint specified in the config file's ``MODEL.WEIGHTS`` is used
instead; this will typically (though not always) initialize a subset of weights using
an ImageNet pre-trained model, while randomly initializing the other weights.
Returns:
CfgNode or omegaconf.DictConfig: a config object
"""
cfg_file = get_config_file(config_path)

if cfg_file.endswith(".py"):
cfg = LazyConfig.load(cfg_file)

return cfg
else:
raise NotImplementedError

0 comments on commit efd11ed

Please sign in to comment.