diff --git a/Makefile b/Makefile index bb417b3..75a20bf 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,7 @@ check: .PHONY: test test: python tests/test_e2e.py + python tests/test_config.py .PHONY: fmt fmt: diff --git a/devcluster/__init__.py b/devcluster/__init__.py index f88653a..b47b656 100644 --- a/devcluster/__init__.py +++ b/devcluster/__init__.py @@ -18,6 +18,8 @@ CustomConfig, CustomDockerConfig, AtomicConfig, + read_path, + deep_merge_configs, ) from devcluster.recovery import ProcessTracker from devcluster.logger import Logger, Log, LogCB diff --git a/devcluster/__main__.py b/devcluster/__main__.py index 968e87e..5075a8c 100644 --- a/devcluster/__main__.py +++ b/devcluster/__main__.py @@ -2,18 +2,21 @@ import argparse import contextlib +import pathlib import fcntl import os import subprocess import re import sys -from typing import Iterator, no_type_check, Optional, Sequence +from typing import Iterator, List, no_type_check, Optional, Sequence import appdirs import yaml import devcluster as dc +CONFIG_DIR = pathlib.Path(os.path.expanduser("~/.config/devcluster")) +BASE_CONFIG_PATH = CONFIG_DIR / "_base.yaml" # prefer stdlib importlib.resources over pkg_resources, when available @no_type_check @@ -52,7 +55,7 @@ def get_host_addr_for_docker() -> Optional[str]: if "darwin" in sys.platform: # On macOS, docker runs in a VM and host.docker.internal points to the IP # address of this VM. - return "host.docker.internal" + return os.getenv("DOCKER_LOCALHOST", "host.docker.internal") # On non-macOS, host.docker.internal does not exist. Instead, grab the source IP # address we would use if we had to talk to the internet. The sed command @@ -98,7 +101,7 @@ def maybe_install_default_config() -> Optional[str]: def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument("-c", "--config", dest="config", action="store") + parser.add_argument("-c", "--config", dest="config", action="store", nargs='+', help="Provide one or more config files") parser.add_argument("-1", "--oneshot", dest="oneshot", action="store_true") parser.add_argument("-q", "--quiet", dest="quiet", action="store_true") parser.add_argument("-C", "--cwd", dest="cwd", action="store") @@ -114,6 +117,12 @@ def main() -> None: else: mode = "client" + blacklist = ["PGSERVICE"] # det binary lib/pq complains. + for key in blacklist: + if key in os.environ: + print(f"Unsetting {key} from environment") + del os.environ[key] + # Validate args ok = True if mode == "client": @@ -145,9 +154,30 @@ def main() -> None: if not ok: sys.exit(1) + def expand_path(path: str) -> pathlib.Path: + """ + if the path doesn't exist try to match it with a known config name. + """ + p = pathlib.Path(path) + # TODO: check if it looks like a config. + if not p.exists() or not p.is_file(): + p = CONFIG_DIR / (path + ".yaml") + if not p.exists(): + print(f"Path {path} does not exist", file=sys.stderr) + print("Available configs:") + for f in CONFIG_DIR.iterdir(): + if f.is_file(): + print(f" {f.stem}") + sys.exit(1) + print(f"expaned {path} to {p}") + return p + + # Read config before the cwd. + config_paths: List[pathlib.Path] = [] if args.config is not None: - config_path = args.config + for path in args.config: + config_paths.append(expand_path(path)) else: check_paths = [] # Always support ~/.devcluster.yaml @@ -184,18 +214,29 @@ def main() -> None: sys.exit(1) env["DOCKER_LOCALHOST"] = docker_localhost - with open(config_path) as f: - config_body = yaml.safe_load(f.read()) - if config_body is None: - print(f"config file '{config_path}' is an empty file!", file=sys.stderr) - sys.exit(1) - if not isinstance(config_body, dict): - print( - f"config file '{config_path}' does not represent a dict!", - file=sys.stderr, - ) - sys.exit(1) - config = dc.Config(dc.expand_env(config_body, env)) + def load_config_body(path: str) -> dict: + with open(path) as f: + config_body = yaml.safe_load(f.read()) + if config_body is None: + print(f"config file '{path}' is an empty file!", file=sys.stderr) + sys.exit(1) + if not isinstance(config_body, dict): + print( + f"config file '{path}' does not represent a dict!", + file=sys.stderr, + ) + sys.exit(1) + return config_body + + config_bodies = [load_config_body(str(path)) for path in config_paths] + # base_config_body = load_config_body(str(BASE_CONFIG_PATH)) + config = dc.Config( + *[ + # dc.expand_env(conf_body, env) for conf_body in [base_config_body] + config_bodies + dc.expand_env(conf_body, env) for conf_body in config_bodies + ] + ) + # Process cwd. cwd_path = None diff --git a/devcluster/config.py b/devcluster/config.py index ee056c7..3a95d01 100644 --- a/devcluster/config.py +++ b/devcluster/config.py @@ -3,6 +3,8 @@ import os import signal import string +import typing +import functools from typing import Any, Dict, List, Optional, Set import yaml @@ -477,7 +479,7 @@ def __init__(self, config: Any) -> None: self.name = check_string(config["name"], "CustomConfig.name must be a string") self.env = config.get("env", {}) - check_dict_with_string_keys(self.env, "CustomConfig.pre must be a list of dicts") + check_dict_with_string_keys(self.env, "CustomConfig.env must be dictionary") self.cwd = read_path(config.get("cwd")) if self.cwd is not None: @@ -555,9 +557,59 @@ def read(config: Any) -> "CommandConfig": check_list_of_strings(config, msg) return CommandConfig(config) +T = typing.TypeVar("T") +def deep_merge(obj1: T, obj2: T, predicate) -> T: + if isinstance(obj1, dict) and isinstance(obj2, dict): + return deep_merge_dict(obj1, obj2, predicate) + elif isinstance(obj1, list) and isinstance(obj2, list): + return deep_merge_list(obj1, obj2, predicate) + else: + return obj2 + +def deep_merge_dict(dict1, dict2, predicate): + result = dict1.copy() + for key, value in dict2.items(): + if key in dict1: + result[key] = deep_merge(dict1[key], value, predicate) + else: + result[key] = value + return result + +def deep_merge_list(list1, list2, predicate): + result = list1.copy() + for item2 in list2: + if not any(predicate(item1, item2) for item1 in list1): + result.append(item2) + else: + for index, item1 in enumerate(list1): + if predicate(item1, item2): + result[index] = deep_merge(item1, item2, predicate) + return result + + +def deep_merge_configs(configs: typing.List[dict]) -> typing.Dict: + def should_merge(d1: dict, d2: dict) -> bool: + # is a stage + if d1.keys() == d2.keys() and len(d1.keys()) == 1: + return True + + # is a rp + if d1.get("pool_name") is not None and d1.get("pool_name") == d2.get("pool_name"): + return True + + return False + + return functools.reduce(lambda x, y: deep_merge(x, y, should_merge), configs) + class Config: - def __init__(self, config: Any) -> None: + def __init__(self, *configs: typing.Any) -> None: + assert len(configs) > 0, "must provide at least one config" + if len(configs) > 1: + config = deep_merge_configs(list(configs)) + else: + config = configs[0] + assert isinstance(config, dict), "config must be a dict" allowed = {"stages", "commands", "startup_input", "temp_dir", "cwd"} required = {"stages"} check_keys(allowed, required, config, type(self).__name__) @@ -575,3 +627,4 @@ def __init__(self, config: Any) -> None: self.cwd = read_path(config.get("cwd")) if self.cwd is not None: assert isinstance(self.cwd, str), "cwd must be a string" + diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..7fd6d36 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,48 @@ +import devcluster as dc + +# test deep_merge_configs +def test_deep_merge_configs(): + # Test merge dicts + configs = [ + { + "stages": [ + {"stage1": {"param1": 1, "param2": 2}}, + {"stage5": {"param2": 3, "param3": 4}}, + {"stage4": {"param2": 3, "param3": 4}} + ] + }, + { + "stages": [ + {"stage1": {"param2": 3, "param3": 4}}, + {"stage3": {"param2": 3, "param3": 4}} + ] + }, + ] + merged = dc.deep_merge_configs(configs) + assert merged.get("stages")[0] == {"stage1": {"param1": 1, "param2": 3, "param3": 4}} + + # Test merge lists of dicts with the same pool_name + configs = [ + {"pools": [{"pool_name": "pool1", "param1": 1}, {"pool_name": "pool2", "param1": 1}]}, + {"pools": [{"pool_name": "pool1", "param2": 2}, {"pool_name": "pool2", "param2": 2}]} + ] + merged = dc.deep_merge_configs(configs) + assert merged == {"pools": [ + {"pool_name": "pool1", "param1": 1, "param2": 2}, + {"pool_name": "pool2", "param1": 1, "param2": 2} + ]} + + # Test merge lists of dicts with different pool_names + configs = [ + {"pools": [{"pool_name": "pool1", "param1": 1}, {"pool_name": "pool2", "param1": 1}]}, + {"pools": [{"pool_name": "pool3", "param1": 1}, {"pool_name": "pool4", "param1": 1}]} + ] + merged = dc.deep_merge_configs(configs) + assert merged == {"pools": [ + {"pool_name": "pool1", "param1": 1}, + {"pool_name": "pool2", "param1": 1}, + {"pool_name": "pool3", "param1": 1}, + {"pool_name": "pool4", "param1": 1} + ]} + +test_deep_merge_configs()