from __future__ import annotations
import itertools
import os
import platform
import sys
from collections.abc import Callable # type: ignore[ruleName]
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from importlib.metadata import version
from typing import Any, Literal
from urllib.parse import urljoin
from requests import Session
from requests.adapters import HTTPAdapter
from rich.progress import Progress
from urllib3.util.retry import Retry
__version__ = version("atomicds")
class BaseClient:
    """Base API client implementation.

    Holds the API key and root endpoint, lazily creates the HTTP session, and
    provides helpers for GET and POST/PUT requests as well as for running a
    function concurrently over many sets of keyword arguments.
    """

    def __init__(
        self,
        api_key: str,
        endpoint: str,
    ):
        """
        Args:
            api_key (str): API key.
            endpoint (str): Root API endpoint.
        """
        self.api_key = api_key
        self.endpoint = endpoint
        # Created on first access of the ``session`` property.
        self._session: Session | None = None

    @property
    def session(self) -> Session:
        """Session under which HTTP requests are issued (created on first use)"""
        if self._session is None:
            self._session = self._create_session(self.api_key)
        return self._session

    def _get(
        self,
        sub_url: str,
        params: dict[str, Any] | None = None,
        deserialize: bool = True,
        base_override: str | None = None,
    ) -> list[dict[Any, Any]] | dict[Any, Any] | bytes | None:
        """Method for issuing a GET request

        Args:
            sub_url (str): API sub-url to use.
            params (dict[str, Any] | None): Params to pass in the GET request. Defaults to None.
            deserialize (bool): Whether to JSON deserialize the response data or return raw bytes. Defaults to True.
            base_override (str | None): Base URL to use instead of the client's root endpoint. Defaults to None.

        Raises:
            ClientError: If the response is not OK and its status code is not 404.

        Returns:
            (list[dict] | dict | bytes | None): Deserialized JSON data or raw bytes. Returns None if the
                response is a 404 or has an empty body.
        """
        base_url = base_override or self.endpoint
        response = self.session.get(
            url=urljoin(base_url, sub_url), verify=True, params=params
        )

        if not response.ok:
            # A missing resource is reported as "no data" rather than as an error.
            if response.status_code == 404:
                return None
            raise ClientError(
                f"Problem retrieving data from {sub_url} with parameters {params}. HTTP Error {response.status_code}: {response.text}"
            )

        # Nothing to deserialize or return.
        if not response.content:
            return None

        return response.json() if deserialize else response.content

    def _post_or_put(
        self,
        method: Literal["POST", "PUT"],
        sub_url: str,
        params: dict[str, Any] | None = None,
        body: dict[str, Any] | bytes | None = None,
        headers: dict[str, str] | None = None,
        deserialize: bool = True,
        base_override: str | None = None,
        return_headers: bool = False,
    ) -> list[dict[Any, Any]] | dict[Any, Any] | bytes | None:
        """Method for issuing a POST or PUT request

        Args:
            method (Literal["POST", "PUT"]): Method to use. Anything other than "PUT" is sent as a POST.
            sub_url (str): API sub-url to use.
            params (dict[str, Any] | None): Query params to pass in the request. Defaults to None.
            body (dict[str, Any] | bytes | None): Body data to send. Bytes, bytearrays and objects with a
                ``read`` attribute are sent as-is; anything else is sent as JSON. Defaults to None.
            headers (dict[str, str] | None): Optional headers to include in the request.
            deserialize (bool): Whether to JSON deserialize the response data or return raw bytes. Defaults to True.
            base_override (str | None): Base URL to use instead of the client's root endpoint. Defaults to None.
            return_headers (bool): Whether to return the headers from the response instead of the content. Defaults
                to False.

        Raises:
            ClientError: If the response is not OK and its status code is not 404.

        Returns:
            (list[dict] | dict | bytes | None): Deserialized JSON data, raw bytes, or the response headers when
                ``return_headers`` is set. Returns None if the response is a 404, or if JSON was requested
                but the response body is empty.
        """
        base_url = base_override or self.endpoint
        method_func = self.session.put if method == "PUT" else self.session.post

        # Decide whether to use data= (bytes/streams) or json=
        data_params: dict[str, Any]
        if body is None:
            data_params = {}
        elif isinstance(body, (bytes, bytearray)) or hasattr(body, "read"):
            # Raw bytes and any file-like / RawIOBase object are sent unmodified.
            # The tuple form of isinstance also works on Python versions older than 3.10.
            data_params = {"data": body}
        else:
            # Everything else (dict, list, etc.) goes through JSON
            data_params = {"json": body}

        response = method_func(
            url=urljoin(base_url, sub_url),
            verify=True,
            params=params,
            headers=headers or {},
            **data_params,
        )

        if not response.ok:
            # A missing resource is reported as "no data" rather than as an error.
            if response.status_code == 404:
                return None
            raise ClientError(
                f"Problem sending data to {sub_url}. HTTP Error {response.status_code}: {response.text}"
            )

        if return_headers:
            return response.headers  # type: ignore[return-value]

        if not deserialize:
            return response.content

        # An empty body (e.g. 204 No Content) cannot be parsed as JSON; mirror
        # _get and report it as None instead of raising a decode error.
        if not response.content:
            return None

        return response.json()

    def _multi_thread(
        self,
        func: Callable[..., Any],
        kwargs_list: list[dict[str, Any]],
        progress_bar: Progress | None = None,
        progress_description: str | None = None,
        progress_kwargs: dict | None = None,
        transient: bool = False,
    ) -> list[Any]:
        """Handles running a function concurrently with a ThreadPoolExecutor

        At most ``min(os.cpu_count() or 8, 8)`` calls are in flight at any time; a new
        call is submitted for each one that finishes. An exception raised by ``func``
        propagates to the caller when its result is collected.

        Arguments:
            func (Callable): Function to run concurrently
            kwargs_list (list): List of keyword argument inputs for the function
            progress_bar (Progress | None): Progress bar to show. Defaults to None.
            progress_description (str | None): Progress bar description.
            progress_kwargs (dict | None): Additional kwargs to pass to the progress task.
                Defaults to ``{"pad": ""}`` when not given or empty.
            transient (bool): Whether the progress task is removed once all calls have
                finished. Defaults to False.

        Returns:
            (list[Any]): List of results from passed function in the order of parameters passed
        """
        results_by_index: dict[int, Any] = {}
        remaining_kwargs = iter(kwargs_list)

        task: Any = None
        if progress_bar is not None:
            progress_kwargs = progress_kwargs or {"pad": ""}
            task = progress_bar.add_task(
                progress_description or "",
                total=len(kwargs_list),
                **progress_kwargs,
            )

        num_parallel = min(os.cpu_count() or 8, 8)
        # Position of each submitted call in kwargs_list, used to restore order.
        next_index = itertools.count()
        index_of: dict[Any, int] = {}

        with ThreadPoolExecutor(max_workers=num_parallel) as executor:

            def submit_next(count: int) -> set:
                """Submit up to ``count`` of the not-yet-started calls."""
                submitted = set()
                for kwargs in itertools.islice(remaining_kwargs, count):
                    future = executor.submit(func, **kwargs)
                    index_of[future] = next(next_index)
                    submitted.add(future)
                return submitted

            # Initial batch is bounded by the max number of parallel requests
            futures = submit_next(num_parallel)

            while futures:
                # Wait for at least one future to complete and process finished
                finished, futures = wait(futures, return_when=FIRST_COMPLETED)

                for future in finished:
                    data = future.result()
                    if progress_bar is not None:
                        progress_bar.update(task, advance=1, refresh=True)
                    results_by_index[index_of.pop(future)] = data

                # Populate more futures to replace finished
                futures |= submit_next(len(finished))

        if progress_bar is not None and transient:
            progress_bar.remove_task(task)

        return [results_by_index[index] for index in sorted(results_by_index)]

    @staticmethod
    def _create_session(api_key: str) -> Session:
        """Create a requests session

        The session carries the API key and a descriptive user agent in its headers
        and retries failed requests up to three times.

        Args:
            api_key (str): API key to include in the header.

        Returns:
            (Session): Requests Session object
        """
        # User agent information
        atomicds_info = "atomicds/" + __version__
        python_info = f"Python/{sys.version.split()[0]}"
        platform_info = f"{platform.system()}/{platform.release()}"

        session = Session()
        # NOTE(review): assigning a new dict replaces the headers the session was
        # created with instead of extending them - confirm this is intended.
        session.headers = {
            "X-API-KEY": api_key,
            "user-agent": f"{atomicds_info} ({python_info} {platform_info})",
        }

        # TODO: Add retry setting to configuration somewhere
        max_retry_num = 3
        retry = Retry(
            total=max_retry_num,
            read=max_retry_num,
            connect=max_retry_num,
            respect_retry_after_header=True,
            status_forcelist=[429, 504, 502],
        )
        adapter = HTTPAdapter(max_retries=retry)
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        return session
class ClientError(Exception):
    """Generic error thrown by the Atomic Data Sciences API client.

    Raised by the request helpers when the API responds with an unsuccessful
    status code other than 404.
    """