Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mp_api/client/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

from .client import BaseRester, MPRestError, MPRestWarning
from .client import BaseRester
from .exceptions import MPRestError, MPRestWarning
from .settings import MAPIClientSettings
36 changes: 14 additions & 22 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from tqdm.auto import tqdm
from urllib3.util.retry import Retry

from mp_api.client.core.exceptions import MPRestError
from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import load_json, validate_ids

Expand Down Expand Up @@ -92,11 +93,11 @@ def __init__(
session: requests.Session | None = None,
s3_client: Any | None = None,
debug: bool = False,
monty_decode: bool = True,
use_document_model: bool = True,
timeout: int = 20,
headers: dict | None = None,
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
**kwargs,
):
"""Initialize the REST API helper class.

Expand All @@ -121,13 +122,13 @@ def __init__(
advanced usage only.
s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores.
debug: if True, print the URL for every request
monty_decode: Decode the data using monty into python objects
use_document_model: If False, skip the creating the document model and return data
as a dictionary. This can be simpler to work with but bypasses data validation
and will not give auto-complete for available fields.
timeout: Time in seconds to wait until a request timeout error is thrown
headers: Custom headers for localhost connections.
mute_progress_bars: Whether to disable progress bars.
**kwargs: access to legacy kwargs that may be in the process of being deprecated
"""
# TODO: think about how to migrate from PMG_MAPI_KEY
self.api_key = api_key or os.getenv("MP_API_KEY")
Expand All @@ -136,7 +137,6 @@ def __init__(
)
self.debug = debug
self.include_user_agent = include_user_agent
self.monty_decode = monty_decode
self.use_document_model = use_document_model
self.timeout = timeout
self.headers = headers or {}
Expand All @@ -151,6 +151,12 @@ def __init__(
self._session = session
self._s3_client = s3_client

if "monty_decode" in kwargs:
warnings.warn(
"Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`."
"The client by default returns results consistent with `monty_decode=True`."
)

@property
def session(self) -> requests.Session:
if not self._session:
Expand Down Expand Up @@ -265,7 +271,7 @@ def _post_resource(
response = self.session.post(url, json=payload, verify=True, params=params)

if response.status_code == 200:
data = load_json(response.text, deser=self.monty_decode)
data = load_json(response.text)
if self.document_model and use_document_model:
if isinstance(data["data"], dict):
data["data"] = self.document_model.model_validate(data["data"]) # type: ignore
Expand Down Expand Up @@ -333,7 +339,7 @@ def _patch_resource(
response = self.session.patch(url, json=payload, verify=True, params=params)

if response.status_code == 200:
data = load_json(response.text, deser=self.monty_decode)
data = load_json(response.text)
if self.document_model and use_document_model:
if isinstance(data["data"], dict):
data["data"] = self.document_model.model_validate(data["data"]) # type: ignore
Expand Down Expand Up @@ -384,10 +390,7 @@ def _query_open_data(
Returns:
dict: MontyDecoded data
"""
if not decoder:

def decoder(x):
return load_json(x, deser=self.monty_decode)
decoder = decoder or load_json

file = open(
f"s3://{bucket}/{key}",
Expand Down Expand Up @@ -997,7 +1000,7 @@ def _submit_request_and_process(
)

if response.status_code == 200:
data = load_json(response.text, deser=self.monty_decode)
data = load_json(response.text)
# other sub-urls may use different document models
# the client does not handle this in a particularly smart way currently
if self.document_model and use_document_model:
Expand Down Expand Up @@ -1302,12 +1305,10 @@ def count(self, criteria: dict | None = None) -> int | str:
"""
criteria = criteria or {}
user_preferences = (
self.monty_decode,
self.use_document_model,
self.mute_progress_bars,
)
self.monty_decode, self.use_document_model, self.mute_progress_bars = (
False,
self.use_document_model, self.mute_progress_bars = (
False,
True,
) # do not waste cycles decoding
Expand All @@ -1329,7 +1330,6 @@ def count(self, criteria: dict | None = None) -> int | str:
)

(
self.monty_decode,
self.use_document_model,
self.mute_progress_bars,
) = user_preferences
Expand All @@ -1351,11 +1351,3 @@ def __str__(self): # pragma: no cover
f"{self.__class__.__name__} connected to {self.endpoint}\n\n"
f"Available fields: {', '.join(self.available_fields)}\n\n"
)


class MPRestError(Exception):
"""Raised when the query has problems, e.g., bad query format."""


class MPRestWarning(Warning):
"""Raised when a query is malformed but interpretable."""
10 changes: 10 additions & 0 deletions mp_api/client/core/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Define custom exceptions and warnings for the client."""
from __future__ import annotations


class MPRestError(Exception):
"""Raised when the query has problems, e.g., bad query format."""


class MPRestWarning(Warning):
"""Raised when a query is malformed but interpretable."""
15 changes: 13 additions & 2 deletions mp_api/client/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from multiprocessing import cpu_count
from typing import List

from pydantic import Field
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from pymatgen.core import _load_pmg_settings

Expand All @@ -14,6 +14,7 @@
_MUTE_PROGRESS_BAR = PMG_SETTINGS.get("MPRESTER_MUTE_PROGRESS_BARS", False)
_MAX_HTTP_URL_LENGTH = PMG_SETTINGS.get("MPRESTER_MAX_HTTP_URL_LENGTH", 2000)
_MAX_LIST_LENGTH = min(PMG_SETTINGS.get("MPRESTER_MAX_LIST_LENGTH", 10000), 10000)
_DEFAULT_ENDPOINT = "https://api.materialsproject.org/"

try:
CPU_COUNT = cpu_count()
Expand Down Expand Up @@ -80,11 +81,21 @@ class MAPIClientSettings(BaseSettings):
)

MIN_EMMET_VERSION: str = Field(
"0.54.0", description="Minimum compatible version of emmet-core for the client."
"0.86.3rc0",
description="Minimum compatible version of emmet-core for the client.",
)

MAX_LIST_LENGTH: int = Field(
_MAX_LIST_LENGTH, description="Maximum length of query parameter list"
)

ENDPOINT: str = Field(
_DEFAULT_ENDPOINT, description="The default API endpoint to use."
)

model_config = SettingsConfigDict(env_prefix="MPRESTER_")

@field_validator("ENDPOINT", mode="before")
def _get_endpoint_from_env(cls, v: str | None) -> str:
"""Support setting endpoint via MP_API_ENDPOINT environment variable."""
return v or os.environ.get("MP_API_ENDPOINT") or _DEFAULT_ENDPOINT
25 changes: 23 additions & 2 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Literal

import orjson
Expand All @@ -8,6 +9,7 @@
from monty.json import MontyDecoder
from packaging.version import parse as parse_version

from mp_api.client.core.exceptions import MPRestError
from mp_api.client.core.settings import MAPIClientSettings

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,20 +52,39 @@ def load_json(
return MontyDecoder().process_decoded(data) if deser else data


def validate_api_key(api_key: str | None = None) -> str:
"""Find and validate an API key."""
# SETTINGS tries to read API key from ~/.config/.pmgrc.yaml
api_key = api_key or os.getenv("MP_API_KEY")
if not api_key:
from pymatgen.core import SETTINGS

api_key = SETTINGS.get("PMG_MAPI_KEY")

if not api_key or len(api_key) != 32:
addendum = " Valid API keys are 32 characters." if api_key else ""
raise MPRestError(
"Please obtain a valid API key from https://materialsproject.org/api "
f"and export it as an environment variable `MP_API_KEY`.{addendum}"
)

return api_key


def validate_ids(id_list: list[str]) -> list[str]:
"""Function to validate material and task IDs.

Args:
id_list (List[str]): List of material or task IDs.

Raises:
ValueError: If at least one ID is not formatted correctly.
MPRestError: If at least one ID is not formatted correctly.

Returns:
id_list: Returns original ID list if everything is formatted correctly.
"""
if len(id_list) > MAPIClientSettings().MAX_LIST_LENGTH:
raise ValueError(
raise MPRestError(
"List of material/molecule IDs provided is too long. Consider removing the ID filter to automatically pull"
" data for all IDs and filter locally."
)
Expand Down
Loading