diff --git a/docs/api_reference.rst b/docs/api_reference.rst new file mode 100644 index 0000000..6a24941 --- /dev/null +++ b/docs/api_reference.rst @@ -0,0 +1,76 @@ +API Reference +============= + + +The RydState python API can be accessed via the ``rydstate`` module by + +.. code-block:: python + + import rydstate + + +All the available classes, methods and functions are documented below: + +.. currentmodule:: rydstate + +**Rydberg States** + +.. autosummary:: + :toctree: _autosummary/ + + RydbergStateSQDT + RydbergStateSQDTAlkali + RydbergStateSQDTAlkalineLS + RydbergStateSQDTAlkalineJJ + RydbergStateSQDTAlkalineFJ + +**Rydberg Basis** + +.. autosummary:: + :toctree: _autosummary/ + + BasisSQDTAlkali + BasisSQDTAlkalineLS + BasisSQDTAlkalineJJ + BasisSQDTAlkalineFJ + +**Angular module** + +.. autosummary:: + :toctree: _autosummary/ + + angular.AngularKetLS + angular.AngularKetJJ + angular.AngularKetFJ + angular.AngularState + angular.utils + + +**Radial module** + +.. autosummary:: + :toctree: _autosummary/ + + radial.RadialKet + radial.Wavefunction + radial.Model + radial.numerov + +**Species module and parameters** + +.. autosummary:: + :toctree: _autosummary/ + + species.SpeciesObject + species.HydrogenTextBook + species.Hydrogen + species.Lithium + species.Sodium + species.Potassium + species.Rubidium + species.Cesium + species.Strontium87 + species.Strontium88 + species.Ytterbium171 + species.Ytterbium173 + species.Ytterbium174 diff --git a/docs/index.rst b/docs/index.rst index 83d2f4e..6849c25 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,11 +9,4 @@ :hidden: examples.rst - - -.. toctree:: - :maxdepth: 2 - :caption: References - :hidden: - - modules.rst + api_reference.rst diff --git a/docs/modules.rst b/docs/modules.rst deleted file mode 100644 index 4a7cdb4..0000000 --- a/docs/modules.rst +++ /dev/null @@ -1,18 +0,0 @@ -API Reference -============= - - -The RydState python API can be accessed via the ``rydstate`` module by - -.. code-block:: python - - import rydstate - - -All the available classes, methods and functions are documented here: - -.. autosummary:: - :toctree: _autosummary/modules - :recursive: - - rydstate diff --git a/src/rydstate/__init__.py b/src/rydstate/__init__.py index cda42f0..739e3ad 100644 --- a/src/rydstate/__init__.py +++ b/src/rydstate/__init__.py @@ -1,5 +1,10 @@ -from rydstate import angular, radial, species -from rydstate.basis import BasisSQDTAlkali, BasisSQDTAlkalineFJ, BasisSQDTAlkalineJJ, BasisSQDTAlkalineLS +from rydstate import angular, basis, radial, rydberg, species +from rydstate.basis import ( + BasisSQDTAlkali, + BasisSQDTAlkalineFJ, + BasisSQDTAlkalineJJ, + BasisSQDTAlkalineLS, +) from rydstate.rydberg import ( RydbergStateSQDT, RydbergStateSQDTAlkali, @@ -20,7 +25,9 @@ "RydbergStateSQDTAlkalineJJ", "RydbergStateSQDTAlkalineLS", "angular", + "basis", "radial", + "rydberg", "species", "ureg", ] diff --git a/src/rydstate/angular/__init__.py b/src/rydstate/angular/__init__.py index de867fa..66f0142 100644 --- a/src/rydstate/angular/__init__.py +++ b/src/rydstate/angular/__init__.py @@ -1,3 +1,4 @@ +from rydstate.angular import utils from rydstate.angular.angular_ket import AngularKetFJ, AngularKetJJ, AngularKetLS from rydstate.angular.angular_state import AngularState @@ -6,4 +7,5 @@ "AngularKetJJ", "AngularKetLS", "AngularState", + "utils", ] diff --git a/src/rydstate/angular/angular_ket.py b/src/rydstate/angular/angular_ket.py index ca6a650..a21b2ce 100644 --- a/src/rydstate/angular/angular_ket.py +++ b/src/rydstate/angular/angular_ket.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import logging from abc import ABC from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload @@ -13,14 +14,12 @@ is_angular_operator_type, ) from rydstate.angular.utils import ( - calc_wigner_3j, check_spin_addition_rule, - clebsch_gordan_6j, - clebsch_gordan_9j, get_possible_quantum_number_values, minus_one_pow, try_trivial_spin_addition, ) +from rydstate.angular.wigner_symbols import calc_wigner_3j, clebsch_gordan_6j, clebsch_gordan_9j from rydstate.species import SpeciesObject if TYPE_CHECKING: @@ -773,15 +772,17 @@ def quantum_numbers_to_angular_ket( Optional, only needed for concrete angular matrix elements. """ - if all(qn is None for qn in [j_c, f_c, j_r]): + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): return AngularKetLS( s_c=s_c, l_c=l_c, s_r=s_r, l_r=l_r, s_tot=s_tot, l_tot=l_tot, j_tot=j_tot, f_tot=f_tot, m=m, species=species ) - if all(qn is None for qn in [s_tot, l_tot, f_c]): + + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): return AngularKetJJ( s_c=s_c, l_c=l_c, j_c=j_c, s_r=s_r, l_r=l_r, j_r=j_r, j_tot=j_tot, f_tot=f_tot, m=m, species=species ) - if all(qn is None for qn in [s_tot, l_tot, j_tot]): + + with contextlib.suppress(InvalidQuantumNumbersError, ValueError): return AngularKetFJ( s_c=s_c, l_c=l_c, j_c=j_c, f_c=f_c, s_r=s_r, l_r=l_r, j_r=j_r, f_tot=f_tot, m=m, species=species ) diff --git a/src/rydstate/angular/angular_matrix_element.py b/src/rydstate/angular/angular_matrix_element.py index e07ac23..e50839b 100644 --- a/src/rydstate/angular/angular_matrix_element.py +++ b/src/rydstate/angular/angular_matrix_element.py @@ -7,7 +7,8 @@ import numpy as np from typing_extensions import TypeGuard -from rydstate.angular.utils import calc_wigner_3j, calc_wigner_6j, minus_one_pow +from rydstate.angular.utils import minus_one_pow +from rydstate.angular.wigner_symbols import calc_wigner_3j, calc_wigner_6j if TYPE_CHECKING: from typing_extensions import ParamSpec diff --git a/src/rydstate/angular/utils.py b/src/rydstate/angular/utils.py index 03ad000..ecea008 100644 --- a/src/rydstate/angular/utils.py +++ b/src/rydstate/angular/utils.py @@ -1,215 +1,6 @@ from __future__ import annotations -import math -from functools import lru_cache, wraps -from typing import TYPE_CHECKING, Callable, TypeVar - import numpy as np -from sympy import Integer -from sympy.physics.wigner import ( - wigner_3j as sympy_wigner_3j, - wigner_6j as sympy_wigner_6j, - wigner_9j as sympy_wigner_9j, -) - -if TYPE_CHECKING: - from typing_extensions import ParamSpec - - P = ParamSpec("P") - R = TypeVar("R") - - def lru_cache(maxsize: int) -> Callable[[Callable[P, R]], Callable[P, R]]: ... # type: ignore [no-redef] - - -# global variables to possibly improve the performance of wigner j calculations -# in the public release we will always use CHECK_ARGS = True and USE_SYMMETRIES = False to reduce potential of bugs -CHECK_ARGS = True -USE_SYMMETRIES = False - - -def sympify_args(func: Callable[P, R]) -> Callable[P, R]: - """Check that quantum numbers are valid and convert to sympy.Integer (and half-integer).""" - if not CHECK_ARGS: - return func - - def check_arg(arg: float) -> Integer: - if isinstance(arg, int) or arg.is_integer(): - return Integer(int(arg)) - if isinstance(arg * 2, int) or (arg * 2).is_integer(): - return Integer(int(arg * 2)) / Integer(2) - raise ValueError(f"Invalid input to {func.__name__}: {arg}.") - - @wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - _args = [check_arg(arg) for arg in args] # type: ignore[arg-type] - _kwargs = {key: check_arg(value) for key, value in kwargs.items()} # type: ignore[arg-type] - return func(*_args, **_kwargs) - - return wrapper - - -@lru_cache(maxsize=100_000) -@sympify_args -def calc_wigner_3j(j1: float, j2: float, j3: float, m1: float, m2: float, m3: float) -> float: - """Calculate the Wigner 3j symbol using lru_cache to improve performance.""" - return float(sympy_wigner_3j(j1, j2, j3, m1, m2, m3).evalf()) - - -@lru_cache(maxsize=100_000) -@sympify_args -def calc_wigner_6j(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float: - """Calculate the Wigner 6j symbol using lru_cache to improve performance.""" - return float(sympy_wigner_6j(j1, j2, j3, j4, j5, j6).evalf()) - - -@lru_cache(maxsize=10_000) -@sympify_args -def calc_wigner_9j( - j1: float, j2: float, j3: float, j4: float, j5: float, j6: float, j7: float, j8: float, j9: float -) -> float: - """Calculate the Wigner 9j symbol using lru_cache to improve performance.""" - return float(sympy_wigner_9j(j1, j2, j3, j4, j5, j6, j7, j8, j9).evalf()) - - -def clebsch_gordan_6j(j1: float, j2: float, j3: float, j12: float, j23: float, j_tot: float) -> float: - """Calculate the overlap between <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>. - - We follow the convention of equation (6.1.5) from Edmonds 1985 "Angular Momentum in Quantum Mechanics". - - See Also: - - https://en.wikipedia.org/wiki/Racah_W-coefficient - - https://en.wikipedia.org/wiki/6-j_symbol - - Args: - j1: Spin quantum number 1. - j2: Spin quantum number 2. - j3: Spin quantum number 3. - j12: Total spin quantum number of j1 + j2. - j23: Total spin quantum number of j2 + j3. - j_tot: Total spin quantum number of j1 + j2 + j3. - - Returns: - The Clebsch-Gordan coefficient <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>. - - """ - prefactor = minus_one_pow(j1 + j2 + j3 + j_tot) * math.sqrt((2 * j12 + 1) * (2 * j23 + 1)) - wigner_6j = calc_wigner_6j(j1, j2, j12, j3, j_tot, j23) - return prefactor * wigner_6j - - -def clebsch_gordan_9j( - j1: float, j2: float, j12: float, j3: float, j4: float, j34: float, j13: float, j24: float, j_tot: float -) -> float: - """Calculate the overlap between <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>. - - We follow the convention of equation (6.4.2) from Edmonds 1985 "Angular Momentum in Quantum Mechanics". - - See Also: - - https://en.wikipedia.org/wiki/9-j_symbol - - Args: - j1: Spin quantum number 1. - j2: Spin quantum number 2. - j12: Total spin quantum number of j1 + j2. - j3: Spin quantum number 1. - j4: Spin quantum number 2. - j34: Total spin quantum number of j1 + j2. - j13: Total spin quantum number of j1 + j3. - j24: Total spin quantum number of j2 + j4. - j_tot: Total spin quantum number of j1 + j2 + j3 + j4. - - Returns: - The Clebsch-Gordan coefficient <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>. - - """ - prefactor = math.sqrt((2 * j12 + 1) * (2 * j34 + 1) * (2 * j13 + 1) * (2 * j24 + 1)) - return prefactor * calc_wigner_9j(j1, j2, j12, j3, j4, j34, j13, j24, j_tot) - - -def calc_wigner_3j_with_symmetries(j1: float, j2: float, j3: float, m1: float, m2: float, m3: float) -> float: - """Calculate the Wigner 3j symbol using symmetries to reduce the number of symbols, that are not cached.""" - symmetry_factor: float = 1 - - # even permutation -> sort smallest j to be j1 - if j2 < j1 and j2 < j3: - j1, j2, j3, m1, m2, m3 = j2, j3, j1, m2, m3, m1 - elif j3 < j1 and j3 < j2: - j1, j2, j3, m1, m2, m3 = j3, j1, j2, m3, m1, m2 - - # odd permutation -> sort second smallest j to be j2 - if j3 < j2: - symmetry_factor *= minus_one_pow(j1 + j2 + j3) - j1, j2, j3, m1, m2, m3 = j1, j3, j2, m1, m3, m2 # noqa: PLW0127 - - # sign of m -> make m1 positive (or m2 if m1==0) - if m1 <= 0 or (m1 == 0 and m2 < 0): - symmetry_factor *= minus_one_pow(j1 + j2 + j3) - m1, m2, m3 = -m1, -m2, -m3 - - # TODO Regge symmetries - - return symmetry_factor * calc_wigner_3j(j1, j2, j3, m1, m2, m3) - - -def calc_wigner_6j_with_symmetries(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float: - """Calculate the Wigner 6j symbol using symmetries to reduce the number of symbols, that are not cached.""" - # interchange upper and lower for 2 columns -> make j1 < j4 and j2 < j5 - if j4 < j1: - j1, j2, j3, j4, j5, j6 = j4, j2, j6, j1, j5, j3 # noqa: PLW0127 - if j5 < j2: - j1, j2, j3, j4, j5, j6 = j1, j5, j6, j4, j2, j3 # noqa: PLW0127 - - # any permutation of columns -> make j1 <= j2 <= j3 - if j2 < j1 and j2 < j3: - j1, j2, j3, j4, j5, j6 = j2, j1, j3, j5, j4, j6 # noqa: PLW0127 - elif j3 < j1 and j3 < j2: - j1, j2, j3, j4, j5, j6 = j3, j2, j1, j6, j5, j4 # noqa: PLW0127 - - if j3 < j2: - j1, j2, j3, j4, j5, j6 = j1, j3, j2, j4, j6, j5 # noqa: PLW0127 - - return calc_wigner_6j(j1, j2, j3, j4, j5, j6) - - -def calc_wigner_9j_with_symmetries( - j1: float, j2: float, j3: float, j4: float, j5: float, j6: float, j7: float, j8: float, j9: float -) -> float: - """Calculate the Wigner 9j symbol using symmetries to reduce the number of symbols, that are not cached.""" - symmetry_factor: float = 1 - js = [j1, j2, j3, j4, j5, j6, j7, j8, j9] - - # even permutation of rows and columns -> make smallest j to be j1 - min_j = min(js) - if min_j not in js[:3]: - if min_j in js[3:6]: - js = [*js[3:6], *js[6:9], *js[0:3]] - elif min_j in js[6:9]: - js = [*js[6:9], *js[0:3], *js[3:6]] - if js[0] != min_j: - if js[1] == min_j: - js = [js[1], js[2], js[0], js[4], js[5], js[3], js[7], js[8], js[6]] - elif js[2] == min_j: - js = [js[2], js[0], js[1], js[5], js[3], js[4], js[8], js[6], js[7]] - - # odd permutations of rows and columns-> make j2 <= j3 and j4 <= j7 - if js[2] < js[1]: - symmetry_factor *= minus_one_pow(sum(js)) - js = [js[0], js[2], js[1], js[3], js[5], js[4], js[6], js[8], js[7]] - if js[6] < js[3]: - symmetry_factor *= minus_one_pow(sum(js)) - js = [*js[0:3], *js[6:9], *js[3:6]] - - # reflection about diagonal -> make j2 <= j4 - if js[3] < js[1]: - js = [js[0], js[3], js[6], js[1], js[4], js[7], js[2], js[5], js[8]] - - return symmetry_factor * calc_wigner_9j(*js) - - -if USE_SYMMETRIES: - calc_wigner_3j = calc_wigner_3j_with_symmetries # type: ignore [assignment] - calc_wigner_6j = calc_wigner_6j_with_symmetries # type: ignore [assignment] - calc_wigner_9j = calc_wigner_9j_with_symmetries # type: ignore [assignment] def minus_one_pow(n: float) -> int: @@ -236,11 +27,12 @@ def try_trivial_spin_addition(s_1: float, s_2: float, s_tot: float | None, name: def check_spin_addition_rule(s_1: float, s_2: float, s_tot: float) -> bool: - """Check if the spin addition rule is satisfied. + r"""Check if the spin addition rule is satisfied. This means check the following conditions: - - |s_1 - s_2| <= s_tot <= s_1 + s_2 - - s_1 + s_2 + s_tot is an integer + :math:`|s_1 - s_2| \leq s_{tot} \leq s_1 + s_2` + and + :math:`s_1 + s_2 + s_{tot}` is an integer """ return abs(s_1 - s_2) <= s_tot <= s_1 + s_2 and (s_1 + s_2 + s_tot) % 1 == 0 diff --git a/src/rydstate/angular/wigner_symbols.py b/src/rydstate/angular/wigner_symbols.py new file mode 100644 index 0000000..c1ef37c --- /dev/null +++ b/src/rydstate/angular/wigner_symbols.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import math +from functools import lru_cache, wraps +from typing import TYPE_CHECKING, Callable, TypeVar + +from sympy import Integer +from sympy.physics.wigner import ( + wigner_3j as sympy_wigner_3j, + wigner_6j as sympy_wigner_6j, + wigner_9j as sympy_wigner_9j, +) + +from rydstate.angular.utils import minus_one_pow + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + P = ParamSpec("P") + R = TypeVar("R") + + def lru_cache(maxsize: int) -> Callable[[Callable[P, R]], Callable[P, R]]: ... # type: ignore [no-redef] + + +# global variables to possibly improve the performance of wigner j calculations +# in the public release we will always use CHECK_ARGS = True and USE_SYMMETRIES = False to reduce potential of bugs +CHECK_ARGS = True +USE_SYMMETRIES = False + + +def sympify_args(func: Callable[P, R]) -> Callable[P, R]: + """Check that quantum numbers are valid and convert to sympy.Integer (and half-integer).""" + if not CHECK_ARGS: + return func + + def check_arg(arg: float) -> Integer: + if isinstance(arg, int) or arg.is_integer(): + return Integer(int(arg)) + if isinstance(arg * 2, int) or (arg * 2).is_integer(): + return Integer(int(arg * 2)) / Integer(2) + raise ValueError(f"Invalid input to {func.__name__}: {arg}.") + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + _args = [check_arg(arg) for arg in args] # type: ignore[arg-type] + _kwargs = {key: check_arg(value) for key, value in kwargs.items()} # type: ignore[arg-type] + return func(*_args, **_kwargs) + + return wrapper + + +@lru_cache(maxsize=100_000) +@sympify_args +def calc_wigner_3j(j1: float, j2: float, j3: float, m1: float, m2: float, m3: float) -> float: + """Calculate the Wigner 3j symbol using lru_cache to improve performance.""" + return float(sympy_wigner_3j(j1, j2, j3, m1, m2, m3).evalf()) + + +@lru_cache(maxsize=100_000) +@sympify_args +def calc_wigner_6j(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float: + """Calculate the Wigner 6j symbol using lru_cache to improve performance.""" + return float(sympy_wigner_6j(j1, j2, j3, j4, j5, j6).evalf()) + + +@lru_cache(maxsize=10_000) +@sympify_args +def calc_wigner_9j( + j1: float, j2: float, j3: float, j4: float, j5: float, j6: float, j7: float, j8: float, j9: float +) -> float: + """Calculate the Wigner 9j symbol using lru_cache to improve performance.""" + return float(sympy_wigner_9j(j1, j2, j3, j4, j5, j6, j7, j8, j9).evalf()) + + +def clebsch_gordan_6j(j1: float, j2: float, j3: float, j12: float, j23: float, j_tot: float) -> float: + """Calculate the overlap between <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>. + + We follow the convention of equation (6.1.5) from Edmonds 1985 "Angular Momentum in Quantum Mechanics". + + See Also: + - https://en.wikipedia.org/wiki/Racah_W-coefficient + - https://en.wikipedia.org/wiki/6-j_symbol + + Args: + j1: Spin quantum number 1. + j2: Spin quantum number 2. + j3: Spin quantum number 3. + j12: Total spin quantum number of j1 + j2. + j23: Total spin quantum number of j2 + j3. + j_tot: Total spin quantum number of j1 + j2 + j3. + + Returns: + The Clebsch-Gordan coefficient <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>. + + """ + prefactor = minus_one_pow(j1 + j2 + j3 + j_tot) * math.sqrt((2 * j12 + 1) * (2 * j23 + 1)) + wigner_6j = calc_wigner_6j(j1, j2, j12, j3, j_tot, j23) + return prefactor * wigner_6j + + +def clebsch_gordan_9j( + j1: float, j2: float, j12: float, j3: float, j4: float, j34: float, j13: float, j24: float, j_tot: float +) -> float: + """Calculate the overlap between <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>. + + We follow the convention of equation (6.4.2) from Edmonds 1985 "Angular Momentum in Quantum Mechanics". + + See Also: + - https://en.wikipedia.org/wiki/9-j_symbol + + Args: + j1: Spin quantum number 1. + j2: Spin quantum number 2. + j12: Total spin quantum number of j1 + j2. + j3: Spin quantum number 1. + j4: Spin quantum number 2. + j34: Total spin quantum number of j1 + j2. + j13: Total spin quantum number of j1 + j3. + j24: Total spin quantum number of j2 + j4. + j_tot: Total spin quantum number of j1 + j2 + j3 + j4. + + Returns: + The Clebsch-Gordan coefficient <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>. + + """ + prefactor = math.sqrt((2 * j12 + 1) * (2 * j34 + 1) * (2 * j13 + 1) * (2 * j24 + 1)) + return prefactor * calc_wigner_9j(j1, j2, j12, j3, j4, j34, j13, j24, j_tot) + + +def calc_wigner_3j_with_symmetries(j1: float, j2: float, j3: float, m1: float, m2: float, m3: float) -> float: + """Calculate the Wigner 3j symbol using symmetries to reduce the number of symbols, that are not cached.""" + symmetry_factor: float = 1 + + # even permutation -> sort smallest j to be j1 + if j2 < j1 and j2 < j3: + j1, j2, j3, m1, m2, m3 = j2, j3, j1, m2, m3, m1 + elif j3 < j1 and j3 < j2: + j1, j2, j3, m1, m2, m3 = j3, j1, j2, m3, m1, m2 + + # odd permutation -> sort second smallest j to be j2 + if j3 < j2: + symmetry_factor *= minus_one_pow(j1 + j2 + j3) + j1, j2, j3, m1, m2, m3 = j1, j3, j2, m1, m3, m2 # noqa: PLW0127 + + # sign of m -> make m1 positive (or m2 if m1==0) + if m1 < 0 or (m1 == 0 and m2 < 0): + symmetry_factor *= minus_one_pow(j1 + j2 + j3) + m1, m2, m3 = -m1, -m2, -m3 + + # TODO Regge symmetries + + return symmetry_factor * calc_wigner_3j(j1, j2, j3, m1, m2, m3) + + +def calc_wigner_6j_with_symmetries(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float: + """Calculate the Wigner 6j symbol using symmetries to reduce the number of symbols, that are not cached.""" + # interchange upper and lower for 2 columns -> make j1 < j4 and j2 < j5 + if j4 < j1: + j1, j2, j3, j4, j5, j6 = j4, j2, j6, j1, j5, j3 # noqa: PLW0127 + if j5 < j2: + j1, j2, j3, j4, j5, j6 = j1, j5, j6, j4, j2, j3 # noqa: PLW0127 + + # any permutation of columns -> make j1 <= j2 <= j3 + if j2 < j1 and j2 < j3: + j1, j2, j3, j4, j5, j6 = j2, j1, j3, j5, j4, j6 # noqa: PLW0127 + elif j3 < j1 and j3 < j2: + j1, j2, j3, j4, j5, j6 = j3, j2, j1, j6, j5, j4 # noqa: PLW0127 + + if j3 < j2: + j1, j2, j3, j4, j5, j6 = j1, j3, j2, j4, j6, j5 # noqa: PLW0127 + + return calc_wigner_6j(j1, j2, j3, j4, j5, j6) + + +def calc_wigner_9j_with_symmetries( + j1: float, j2: float, j3: float, j4: float, j5: float, j6: float, j7: float, j8: float, j9: float +) -> float: + """Calculate the Wigner 9j symbol using symmetries to reduce the number of symbols, that are not cached.""" + symmetry_factor: float = 1 + js = [j1, j2, j3, j4, j5, j6, j7, j8, j9] + + # even permutation of rows and columns -> make smallest j to be j1 + min_j = min(js) + if min_j not in js[:3]: + if min_j in js[3:6]: + js = [*js[3:6], *js[6:9], *js[0:3]] + elif min_j in js[6:9]: + js = [*js[6:9], *js[0:3], *js[3:6]] + if js[0] != min_j: + if js[1] == min_j: + js = [js[1], js[2], js[0], js[4], js[5], js[3], js[7], js[8], js[6]] + elif js[2] == min_j: + js = [js[2], js[0], js[1], js[5], js[3], js[4], js[8], js[6], js[7]] + + # odd permutations of rows and columns-> make j2 <= j3 and j4 <= j7 + if js[2] < js[1]: + symmetry_factor *= minus_one_pow(sum(js)) + js = [js[0], js[2], js[1], js[3], js[5], js[4], js[6], js[8], js[7]] + if js[6] < js[3]: + symmetry_factor *= minus_one_pow(sum(js)) + js = [*js[0:3], *js[6:9], *js[3:6]] + + # reflection about diagonal -> make j2 <= j4 + if js[3] < js[1]: + js = [js[0], js[3], js[6], js[1], js[4], js[7], js[2], js[5], js[8]] + + return symmetry_factor * calc_wigner_9j(*js) + + +if USE_SYMMETRIES: + calc_wigner_3j = calc_wigner_3j_with_symmetries # type: ignore [assignment] + calc_wigner_6j = calc_wigner_6j_with_symmetries # type: ignore [assignment] + calc_wigner_9j = calc_wigner_9j_with_symmetries # type: ignore [assignment] diff --git a/src/rydstate/basis/__init__.py b/src/rydstate/basis/__init__.py index 08200ec..7f15364 100644 --- a/src/rydstate/basis/__init__.py +++ b/src/rydstate/basis/__init__.py @@ -1,3 +1,4 @@ +from rydstate.basis.basis_base import BasisBase from rydstate.basis.basis_sqdt import BasisSQDTAlkali, BasisSQDTAlkalineFJ, BasisSQDTAlkalineJJ, BasisSQDTAlkalineLS -__all__ = ["BasisSQDTAlkali", "BasisSQDTAlkalineFJ", "BasisSQDTAlkalineJJ", "BasisSQDTAlkalineLS"] +__all__ = ["BasisBase", "BasisSQDTAlkali", "BasisSQDTAlkalineFJ", "BasisSQDTAlkalineJJ", "BasisSQDTAlkalineLS"] diff --git a/src/rydstate/radial/__init__.py b/src/rydstate/radial/__init__.py index 3b21898..29c2c6e 100644 --- a/src/rydstate/radial/__init__.py +++ b/src/rydstate/radial/__init__.py @@ -1,3 +1,4 @@ +from rydstate.radial import numerov from rydstate.radial.grid import Grid from rydstate.radial.model import Model, PotentialType from rydstate.radial.numerov import run_numerov_integration @@ -14,5 +15,6 @@ "WavefunctionNumerov", "WavefunctionWhittaker", "calc_radial_matrix_element_from_w_z", + "numerov", "run_numerov_integration", ] diff --git a/src/rydstate/radial/wavefunction.py b/src/rydstate/radial/wavefunction.py index edc62fc..36db1ca 100644 --- a/src/rydstate/radial/wavefunction.py +++ b/src/rydstate/radial/wavefunction.py @@ -152,7 +152,7 @@ def integrate(self, run_backward: bool = True, w0: float = 1e-10, *, _use_njit: The resulting radial wavefunction is normalized such that .. math:: - \int_{0}^{\infty} r^2 |R(x)|^2 dr + \int_{0}^{\infty} r^2 |R(r)|^2 dr = \int_{0}^{\infty} |\tilde{u}(x)|^2 dx = \int_{0}^{\infty} 2 z^2 |w(z)|^2 dz = 1 diff --git a/src/rydstate/rydberg/__init__.py b/src/rydstate/rydberg/__init__.py index d3dcc50..f7a895b 100644 --- a/src/rydstate/rydberg/__init__.py +++ b/src/rydstate/rydberg/__init__.py @@ -1,3 +1,4 @@ +from rydstate.rydberg.rydberg_base import RydbergStateBase from rydstate.rydberg.rydberg_sqdt import ( RydbergStateSQDT, RydbergStateSQDTAlkali, @@ -7,6 +8,7 @@ ) __all__ = [ + "RydbergStateBase", "RydbergStateSQDT", "RydbergStateSQDTAlkali", "RydbergStateSQDTAlkalineFJ", diff --git a/src/rydstate/rydberg/rydberg_base.py b/src/rydstate/rydberg/rydberg_base.py index e13e00c..0eeec28 100644 --- a/src/rydstate/rydberg/rydberg_base.py +++ b/src/rydstate/rydberg/rydberg_base.py @@ -13,9 +13,7 @@ class RydbergStateBase(ABC): - @property - @abstractmethod - def angular(self) -> AngularState[Any] | AngularKetBase: ... + angular: AngularState[Any] | AngularKetBase @abstractmethod def calc_reduced_overlap(self, other: RydbergStateBase) -> float: ... diff --git a/src/rydstate/rydberg/rydberg_sqdt.py b/src/rydstate/rydberg/rydberg_sqdt.py index 8c12dc4..744b582 100644 --- a/src/rydstate/rydberg/rydberg_sqdt.py +++ b/src/rydstate/rydberg/rydberg_sqdt.py @@ -24,6 +24,10 @@ class RydbergStateSQDT(RydbergStateBase): species: SpeciesObject + """The atomic species of the Rydberg state.""" + + angular: AngularKetBase + """The angular/spin part of the Rydberg electron.""" def __init__( self, @@ -69,7 +73,8 @@ def __init__( species = SpeciesObject.from_name(species) self.species = species - self._qns = dict( # noqa: C408 + self.angular = quantum_numbers_to_angular_ket( + species=self.species, s_c=s_c, l_c=l_c, j_c=j_c, @@ -89,6 +94,30 @@ def __init__( if nu is None and n is None: raise ValueError("Either n or nu must be given to initialize the Rydberg state.") + @classmethod + def from_angular_ket( + cls, + species: str | SpeciesObject, + angular_ket: AngularKetBase, + n: int | None = None, + nu: float | None = None, + ) -> RydbergStateSQDT: + """Initialize the Rydberg state from an angular ket.""" + obj = cls.__new__(cls) + + if isinstance(species, str): + species = SpeciesObject.from_name(species) + obj.species = species + + obj.n = n + obj._nu = nu # noqa: SLF001 + if nu is None and n is None: + raise ValueError("Either n or nu must be given to initialize the Rydberg state.") + + obj.angular = angular_ket + + return obj + def __repr__(self) -> str: species, n, nu = self.species.name, self.n, self.nu n_str = f", {n=}" if n is not None else "" @@ -100,6 +129,11 @@ def __str__(self) -> str: @cached_property def radial(self) -> RadialKet: """The radial part of the Rydberg electron.""" + if "l_r" not in self.angular.quantum_number_names: + raise ValueError( + f"l_r must be defined in the angular ket to access the radial ket, but angular={self.angular}." + ) + radial_ket = RadialKet(self.species, nu=self.nu, l_r=self.angular.l_r) if self.n is not None: radial_ket.set_n_for_sanity_check(self.n) @@ -112,11 +146,6 @@ def radial(self) -> RadialKet: ) return radial_ket - @cached_property - def angular(self) -> AngularKetBase: - """The angular/spin part of the Rydberg electron.""" - return quantum_numbers_to_angular_ket(species=self.species, **self._qns) # type: ignore [arg-type] - @cached_property def nu(self) -> float: """The effective principal quantum number nu (for alkali atoms also known as n*) for the Rydberg state.""" diff --git a/tests/test_all_elements.py b/tests/test_all_elements.py index 6321b05..762b28f 100644 --- a/tests/test_all_elements.py +++ b/tests/test_all_elements.py @@ -18,7 +18,7 @@ def test_magnetic(species_name: str) -> None: if species.number_valence_electrons == 1: state = RydbergStateSQDTAlkali(species, n=50, l=0, f=i_c + 0.5) state.radial.create_wavefunction() - with pytest.raises(ValueError, match="j_tot must be set"): + with pytest.raises(ValueError, match="Invalid combination of angular quantum numbers provided"): RydbergStateSQDTAlkali(species, n=50, l=1) elif species.number_valence_electrons == 2 and species._quantum_defects is not None: # noqa: SLF001 for s_tot in [0, 1]: