diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index 7df557ef..bb200fbb 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -1,8 +1,11 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from emmet.core.settings import EmmetSettings from emmet.core.symmetry import CrystalSystem -from emmet.core.vasp.material import MaterialsDoc +from emmet.core.vasp.calc_types import RunType +from emmet.core.vasp.material import BlessedCalcs, MaterialsDoc from pymatgen.core.structure import Structure from mp_api.client.core import BaseRester, MPRestError @@ -37,6 +40,11 @@ XASRester, ) +if TYPE_CHECKING: + from typing import Any + + from pymatgen.entries.computed_entries import ComputedStructureEntry + _EMMET_SETTINGS = EmmetSettings() # type: ignore @@ -318,3 +326,75 @@ def find_structure( return material_ids # type: ignore return material_ids[0] + + def get_blessed_entries( + self, + run_type: str | RunType = RunType.r2SCAN, + material_ids: list[str] | None = None, + uncorrected_energy: tuple[float | None, float | None] | float | None = None, + num_chunks: int | None = None, + chunk_size: int = 1000, + ) -> list[dict[str, str | dict | ComputedStructureEntry]]: + """Get blessed calculation entries for a given material and run type. + + Args: + run_type (str or RunType): Calculation run type (e.g. GGA, GGA+U, r2SCAN, PBESol) + material_ids (list[str]): List of material ID values + uncorrected_energy (tuple[Optional[float], Optional[float]] | float): Tuple of minimum and maximum uncorrected DFT energy in eV/atom. + Note that if a single value is passed, it will be used as the minimum and maximum. + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int): Number of data entries per chunk. + + Returns: + list of dict, of the form: + { + "material_id": MPID, + "blessed_entry": ComputedStructureEntry + } + """ + query_params: dict[str, Any] = {"run_type": str(run_type)} + if material_ids: + if isinstance(material_ids, str): + material_ids = [material_ids] + + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + + if uncorrected_energy: + if isinstance(uncorrected_energy, float): + uncorrected_energy = (uncorrected_energy, uncorrected_energy) + + query_params.update( + { + "energy_min": uncorrected_energy[0], + "energy_max": uncorrected_energy[1], + } + ) + + results = self._query_resource( + query_params, + fields=["material_id", "entries"], + suburl="blessed_tasks", + parallel_param="material_ids" if material_ids else None, + chunk_size=chunk_size, + num_chunks=num_chunks, + ) + + return [ + { + "material_id": doc["material_id"], + "blessed_entry": ( + next( + getattr(doc["entries"], k, None) + for k in BlessedCalcs.model_fields + if getattr(doc["entries"], k, None) + ) + if self.use_document_model + else next( + doc["entries"][k] + for k in BlessedCalcs.model_fields + if doc["entries"].get(k) + ) + ), + } + for doc in (results.get("data") or []) + ] diff --git a/tests/materials/test_materials.py b/tests/materials/test_materials.py index 09e1e80d..76d2f175 100644 --- a/tests/materials/test_materials.py +++ b/tests/materials/test_materials.py @@ -62,3 +62,24 @@ def test_client(rester): custom_field_tests=custom_field_tests, sub_doc_fields=sub_doc_fields, ) + + +@pytest.mark.xfail(condition=True, reason="Needs new deployment.", strict=False) +@pytest.mark.parametrize( + "run_type, uncorrected_energy, use_document_model", + [("PBE", None, True), ("r2SCAN", 1.0, False), ("GGA_U", (-50e4, 0.0), True)], +) +def test_blessed_entry(run_type, uncorrected_energy, use_document_model): + # Si and NiO. Si has GGA and r2SCAN entries, NiO has GGA, GGA+U, and r2SCAN + with MaterialsRester(use_document_model=use_document_model) as rester: + blessed = rester.get_blessed_entries( + run_type, + material_ids=["mp-149", "mp-19009"], + uncorrected_energy=uncorrected_energy, + ) + + assert all( + isinstance(entry, dict) + and all(entry.get(k) is not None for k in ("material_id", "blessed_entry")) + for entry in blessed + )