diff --git a/CHANGELOG.md b/CHANGELOG.md index 2920f2b..d4853a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 + Updated project structure to use `pyproject.toml` + Add `build` to `dev` dependencies ++ Support `argparse.ArgumentDefaultsHelpFormatter` by properly initializing + argparse Action default. ++ Relax the guarantee that default_factory is called "exactly once per parse" + but still guarantee that default_factory is called "for each parse". ## [2.0.1] - Unreleased diff --git a/README.rst b/README.rst index f4f6979..add831f 100644 --- a/README.rst +++ b/README.rst @@ -71,6 +71,12 @@ Using defaults: >>> print(parser.parse_args([])) Options(x=1, y=2, z=3.14) +Using ArgumentDefaultsHelpFormatter is supported. If a default_factory is used +in the dataclass it will be called for a fresh result on each parse. The default +value provided in --help is initialized at parser setup time. + +Implementation of default_factory with side-effects should not be used. + Enabling choices for an option: .. code-block:: pycon diff --git a/argparse_dataclass.py b/argparse_dataclass.py index c530f09..cb1e7d7 100644 --- a/argparse_dataclass.py +++ b/argparse_dataclass.py @@ -240,6 +240,7 @@ import argparse from argparse import BooleanOptionalAction +from argparse import Namespace from typing import ( TypeVar, Optional, @@ -275,7 +276,8 @@ def parse_args(options_class: Type[OptionsType], args: ArgsType = None) -> Optio """Parse arguments and return as the dataclass type.""" parser = argparse.ArgumentParser() _add_dataclass_options(options_class, parser) - kwargs = _get_kwargs(parser.parse_args(args)) + initial_namespace = _init_namespace(options_class) + kwargs = _get_kwargs(parser.parse_args(args, initial_namespace)) return options_class(**kwargs) @@ -287,7 +289,9 @@ def parse_known_args( """ parser = argparse.ArgumentParser() _add_dataclass_options(options_class, parser) - namespace, others = parser.parse_known_args(args=args) + initial_namespace = _init_namespace(options_class) + namespace, others = parser.parse_known_args(args, initial_namespace) + assert namespace == initial_namespace kwargs = _get_kwargs(namespace) return options_class(**kwargs), others @@ -366,7 +370,10 @@ def _add_dataclass_options( if field.default == field.default_factory == MISSING and not positional: kwargs["required"] = True else: - kwargs["default"] = MISSING + if field.default_factory is not MISSING: + kwargs["default"] = field.default_factory() + else: + kwargs["default"] = field.default if field.type is bool: _handle_bool_type(field, args, kwargs) @@ -389,6 +396,22 @@ def _add_dataclass_options( parser.add_argument(*args, **kwargs) +def _init_namespace(options_class: Type[OptionsType]) -> Namespace: + """Init a namespace for passing into `argparse.ArgumentParser.parse_args` + + Assign a flag value (MISSING) for all fields which have a default at the + dataclass level, this prevents argparse from assigning to those fields. + """ + ns = Namespace() + assert is_dataclass(options_class) + for field in fields(options_class): + if field.default is not MISSING: + setattr(ns, field.name, field.default) + elif field.default_factory is not MISSING: + setattr(ns, field.name, field.default_factory()) + return ns + + def _get_kwargs(namespace: argparse.Namespace) -> dict[str, Any]: """Converts a Namespace to a dictionary containing the items that to be used as keyword arguments to the Options class. @@ -469,10 +492,9 @@ def __init__(self, options_class: Type[OptionsType], *args, **kwargs): def parse_args(self, args: ArgsType = None, namespace=None) -> OptionsType: """Parse arguments and return as the dataclass type.""" - if namespace is not None: - raise ValueError("supplying a namespace is not allowed") - kwargs = _get_kwargs(super().parse_args(args)) - return self._options_type(**kwargs) + opts = super().parse_args(args, namespace) + assert isinstance(opts, self._options_type) + return opts def parse_known_args( self, args: ArgsType = None, namespace=None @@ -482,7 +504,9 @@ def parse_known_args( """ if namespace is not None: raise ValueError("supplying a namespace is not allowed") - namespace, others = super().parse_known_args(args=args) + initial_namespace = _init_namespace(self._options_type) + namespace, others = super().parse_known_args(args, initial_namespace) + assert namespace == initial_namespace kwargs = _get_kwargs(namespace) return self._options_type(**kwargs), others diff --git a/tests/test_argumentparser.py b/tests/test_argumentparser.py index f21dd6f..975f76f 100644 --- a/tests/test_argumentparser.py +++ b/tests/test_argumentparser.py @@ -1,3 +1,4 @@ +from argparse import ArgumentDefaultsHelpFormatter import sys import unittest import datetime as dt @@ -170,27 +171,31 @@ class Parameters: def test_default_factory_2(self): factory_calls = 0 + factory_result = "0" def factory_func(): nonlocal factory_calls factory_calls += 1 - return f"Default Message: {factory_calls}" + return f"Default Message: {factory_result}" @dataclass class Parameters: message: str = field(default_factory=factory_func) - params = ArgumentParser(Parameters).parse_args([]) + parser = ArgumentParser(Parameters) + factory_result = "1" + params = parser.parse_args([]) self.assertEqual(params.message, "Default Message: 1") - self.assertEqual(factory_calls, 1) + self.assertGreaterEqual(factory_calls, 1) - params = ArgumentParser(Parameters).parse_args(["--message", "User message"]) + params = parser.parse_args(["--message", "User message"]) self.assertEqual(params.message, "User message") - self.assertEqual(factory_calls, 1) + self.assertGreaterEqual(factory_calls, 1) - params = ArgumentParser(Parameters).parse_args([]) + factory_result = "2" + params = parser.parse_args([]) self.assertEqual(params.message, "Default Message: 2") - self.assertEqual(factory_calls, 2) + self.assertGreaterEqual(factory_calls, 1) def test_optional_args(self): @dataclass @@ -308,6 +313,22 @@ class Args: self.assertEqual(10, params.num_of_foo) self.assertFalse(params.is_fun) + def test_default_help(self): + @dataclass + class Opt: + answer: int = field( + default=42, + metadata=dict(help="answer"), + ) + + """Test ArgumentsDefaultsHelpFormatter works as expected.""" + parser = ArgumentParser( + Opt, + formatter_class=ArgumentDefaultsHelpFormatter, + ) + help_message = parser.format_help() + assert "answer (default: 42)" in help_message + if __name__ == "__main__": unittest.main() diff --git a/tests/test_functional.py b/tests/test_functional.py index ec41f9f..30a27ef 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -263,27 +263,30 @@ class Parameters: def test_default_factory_2(self): factory_calls = 0 + factory_result = "0" def factory_func(): nonlocal factory_calls factory_calls += 1 - return f"Default Message: {factory_calls}" + return f"Default Message: {factory_result}" @dataclass class Parameters: message: str = field(default_factory=factory_func) + factory_result = "1" params = parse_args(Parameters, []) self.assertEqual(params.message, "Default Message: 1") - self.assertEqual(factory_calls, 1) + self.assertGreaterEqual(factory_calls, 1) params = parse_args(Parameters, ["--message", "User message"]) self.assertEqual(params.message, "User message") - self.assertEqual(factory_calls, 1) + self.assertGreaterEqual(factory_calls, 1) + factory_result = "2" params = parse_args(Parameters, []) self.assertEqual(params.message, "Default Message: 2") - self.assertEqual(factory_calls, 2) + self.assertGreaterEqual(factory_calls, 1) def test_parse_known_args(self): @dataclass