from __future__ import annotations
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any, BinaryIO, Literal
from pandas import DataFrame
from atomicds.core import BaseClient, ClientError, _FileSlice
from atomicds.core.utils import _make_progress, normalize_path
from atomicds.results import (
RHEEDImageResult,
RHEEDVideoResult,
XPSResult,
_get_rheed_image_result,
)
from atomicds.timeseries.registry import get_provider
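# Timeseries domains recognized by the provider registry (see get_provider)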
TimeseriesDomain = Literal["rheed", "optical", "metrology"]
class Client(BaseClient):
"""Atomic Data Sciences API client"""
def __init__(
self,
api_key: str | None = None,
endpoint: str = "https://api.atomicdatasciences.com/",
mute_bars: bool = False,
):
"""
Args:
api_key (str | None): API key. Defaults to None, which will try to pull from the ADS_API_KEY environment variable.
endpoint (str): Root API endpoint. Will prioritize pulling from the ADS_API_ENDPOINT environment variable.
    If none is provided, it defaults to 'https://api.atomicdatasciences.com/'.
mute_bars (bool): Whether to mute progress bars. Defaults to False.
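
Example:
    Minimal usage sketch (illustrative; assumes `Client` is importable
    from the top-level `atomicds` package and that `ADS_API_KEY` is set
    in the environment)::

        from atomicds import Client

        client = Client()  # key pulled from ADS_API_KEY
        quiet_client = Client(api_key="...", mute_bars=True)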
"""
api_key = api_key or os.environ.get("ADS_API_KEY")
endpoint = os.environ.get("ADS_API_ENDPOINT") or endpoint
if api_key is None:
raise ValueError("No valid ADS API key supplied")
self.mute_bars = mute_bars
super().__init__(api_key=api_key, endpoint=endpoint)
def search(
self,
keywords: str | list[str] | None = None,
include_organization_data: bool = True,
data_ids: str | list[str] | None = None,
data_type: Literal[
"rheed_image", "rheed_stationary", "rheed_rotating", "xps", "all"
] = "all",
status: Literal[
"success",
"pending",
"error",
"running",
"stream_active",
"stream_interrupted",
"stream_finalizing",
"stream_error",
"all",
] = "all",
growth_length: tuple[int | None, int | None] = (None, None),
upload_datetime: tuple[datetime | None, datetime | None] = (None, None),
last_accessed_datetime: tuple[datetime | None, datetime | None] = (None, None),
) -> DataFrame:
"""Search and obtain data catalogue entries
Args:
keywords (str | list[str] | None): Keyword or list of keywords to search all data catalogue fields with.
This searching is applied after all other explicit filters. Defaults to None.
include_organization_data (bool): Whether to include catalogue entries from other users in
your organization. Defaults to True.
data_ids (str | list[str] | None): Data ID or list of data IDs. Defaults to None.
data_type (Literal["rheed_image", "rheed_stationary", "rheed_rotating", "xps", "all"]): Type of data. Defaults to "all".
status (Literal["success", "pending", "error", "running", "all"]): Analyzed status of the data. Defaults to "all".
growth_length (tuple[int | None, int | None]): Minimum and maximum values of the growth length in seconds.
    Defaults to (None, None), which applies no growth-length filter and therefore also includes non-video data.
upload_datetime (tuple[datetime | None, datetime | None]): Minimum and maximum values of the upload datetime.
Defaults to (None, None).
last_accessed_datetime (tuple[datetime | None, datetime | None]): Minimum and maximum values of the last accessed datetime.
Defaults to (None, None).
Returns:
(DataFrame): Pandas DataFrame containing matched entries in the data catalogue.
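
Example:
    Illustrative sketch of a filtered search (the keyword and datetime
    values are placeholders)::

        from datetime import datetime

        catalogue = client.search(
            keywords=["GaAs"],
            data_type="rheed_stationary",
            status="success",
            upload_datetime=(datetime(2024, 1, 1), None),
        )
        print(catalogue["Data ID"].tolist())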
"""
params = {
"keywords": keywords,
"include_organization_data": include_organization_data,
"data_ids": data_ids,
"data_type": None if data_type == "all" else data_type,
"status": status,
"growth_length_min": growth_length[0],
"growth_length_max": growth_length[1],
"upload_datetime_min": upload_datetime[0],
"upload_datetime_max": upload_datetime[1],
"last_accessed_datetime_min": last_accessed_datetime[0],
"last_accessed_datetime_max": last_accessed_datetime[1],
}
data = self._get(
sub_url="data_entries/",
params=params,
)
column_mapping = {
"data_id": "Data ID",
"upload_datetime": "Upload Datetime",
"last_accessed_datetime": "Last Accessed Datetime",
"char_source_type": "Type",
"raw_name": "File Name",
"pipeline_status": "Status",
"raw_file_type": "File Type",
"source_name": "Instrument Source",
"sample_name": "Sample Name",
"growth_length": "Growth Length",
"physical_sample_id": "Physical Sample ID",
"physical_sample_name": "Physical Sample Name",
"detail_note_content": "Sample Notes",
"detail_note_last_updated": "Sample Notes Last Updated",
"file_metadata": "File Metadata",
"tags": "Tags",
"name": "Owner",
"workspaces": "Workspaces",
}
columns_to_drop = [
"user_id",
"synth_source_id",
"sample_id",
"processed_file_type",
"bucket_file_name",
]
catalogue = DataFrame(data)
if len(catalogue):
catalogue = catalogue.drop(columns=columns_to_drop)
return catalogue.rename(columns=column_mapping)
def get(
self, data_ids: str | list[str]
) -> list[RHEEDVideoResult | RHEEDImageResult | XPSResult]:
"""Get analyzed data results
Args:
data_ids (str | list[str]): Data ID or list of data IDs from the data catalogue to obtain analyzed results for.
Returns:
(list[RHEEDVideoResult | RHEEDImageResult | XPSResult]): List of result objects.
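
Example:
    Illustrative sketch; data IDs typically come from the "Data ID"
    column of the DataFrame returned by `search`::

        catalogue = client.search(data_type="rheed_image")
        results = client.get(data_ids=catalogue["Data ID"].tolist())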
"""
if isinstance(data_ids, str):
data_ids = [data_ids]
data: list[dict] = self._get( # type: ignore # noqa: PGH003
sub_url="data_entries/",
params={
"data_ids": data_ids,
"include_organization_data": True,
},
)
kwargs_list = []
for entry in data:
data_id = entry["data_id"]
data_type = entry["char_source_type"]
kwargs_list.append({"data_id": data_id, "data_type": data_type})
# preserve the order of the requested data_ids so results can be matched to external labels
kwargs_list = sorted(kwargs_list, key=lambda x: data_ids.index(x["data_id"]))
with _make_progress(self.mute_bars, False) as progress:
return self._multi_thread(
self._get_result_data,
kwargs_list,
progress,
progress_description="Obtaining data results",
)
def _get_result_data(
self,
data_id: str,
data_type: Literal[
"xps",
"rheed_image",
"rheed_stationary",
"rheed_rotating",
"rheed_xscan",
"metrology",
"optical",
],
) -> RHEEDVideoResult | RHEEDImageResult | XPSResult | None:
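# Dispatch on the data type: XPS and single RHEED images have dedicated
# result builders; the remaining timeseries-like types are routed through
# the provider registry.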
if data_type == "xps":
result: dict = self._get(sub_url=f"xps/{data_id}") # type: ignore # noqa: PGH003
return XPSResult(
data_id=data_id,
xps_id=result["xps_id"],
binding_energies=result["binding_energies"],
intensities=result["intensities"],
predicted_composition=result["predicted_composition"],
detected_peaks=result["detected_peaks"],
elements_manually_set=bool(result["set_elements"]),
)
if data_type == "rheed_image":
return _get_rheed_image_result(self, data_id)
if data_type in [
"rheed_stationary",
"rheed_rotating",
"rheed_xscan",
"metrology",
"optical",
]:
timeseries_type = "rheed" if "rheed" in data_type else data_type
provider = get_provider(timeseries_type)
# Get timeseries data
raw = provider.fetch_raw(self, data_id)
ts_df = provider.to_dataframe(raw)
return provider.build_result(self, data_id, data_type, ts_df)
raise ValueError(f"Unsupported data type: {data_type}")
def upload(self, files: list[str | BinaryIO]):
"""Upload and process files
Args:
files (list[str | BinaryIO]): List containing string paths to files, or BinaryIO objects from `open`.
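
Example:
    Illustrative sketch (file names are placeholders)::

        client.upload(files=["sample_a.mp4", "sample_b.mp4"])

        # binary file objects from open() also work
        with open("sample_c.mp4", "rb") as fh:
            client.upload(files=[fh])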
"""
chunk_size = 40 * 1024 * 1024 # 40 MiB
# Validate each input and compute the number of pre-signed part URLs needed per file
file_data = []
for file in files:
if isinstance(file, str):
path = normalize_path(file)
if not (path.exists() and path.is_file()):
raise ClientError(f"{path} is not a file or does not exist")
# Calculate number of URLs needed for this file
file_size = path.stat().st_size
num_urls = -(-file_size // chunk_size) # Ceiling division
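# e.g. a 100 MiB file -> 3 parts (40 MiB + 40 MiB + 20 MiB)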
file_name = path.name
else:
# Handle BinaryIO objects
file.seek(0, 2) # Seek to the end of the file
file_size = file.tell()
file.seek(0) # Seek back to the beginning of the file
num_urls = -(-file_size // chunk_size) # Ceiling division
file_name = file.name
file_data.append(
{
"num_urls": num_urls,
"file_name": file_name,
"file_size": file_size,
"file_path": file,
}
)
def __upload_file(
file_info: dict[
Literal["num_urls", "file_name", "file_size", "file_path"], int | str
],
):
url_data: list[dict[str, str | int]] = self._post_or_put(
method="POST",
sub_url="data_entries/raw_data/staged/upload_urls/",
params={
"original_filename": file_info["file_name"],
"num_parts": file_info["num_urls"],
"staging_type": "core",
},
) # type: ignore # noqa: PGH003
# Upload the file in parts through the pre-signed multipart S3 URLs, multithreaded.
# Build kwargs_list from plain, serializable values only:
kwargs_list = []
for part in url_data:
part_no = int(part["part"]) - 1
offset = part_no * chunk_size
length = min(chunk_size, int(file_info["file_size"]) - offset) # type: ignore # noqa: PGH003
kwargs_list.append(
{
"method": "PUT",
"sub_url": "",
"params": None,
"base_override": part["url"],
"file_path": file_info["file_path"],
"offset": offset,
"length": length,
}
)
def __upload_chunk(
method: Literal["PUT", "POST"],
sub_url: str,
params: dict[str, Any] | None,
base_override: str,
file_path: str | BinaryIO,
offset: int,
length: int,
) -> Any:
slice_obj = _FileSlice(file_path, offset, length)
return self._post_or_put(
method=method,
sub_url=sub_url,
params=params,
body=slice_obj, # type: ignore # noqa: PGH003
deserialize=False,
return_headers=True,
base_override=base_override,
headers={
"Content-Length": str(length),
},
)
etag_data = self._multi_thread(
__upload_chunk,
kwargs_list=kwargs_list,
progress_bar=progress,
progress_description=f"[red]{file_info['file_name']}",
progress_kwargs={
"show_percent": True,
"show_total": False,
"show_spinner": False,
"pad": "",
},
transient=True,
)
# Complete multipart upload *only* if the backend issued an upload_id
first_part = url_data[0]
upload_id = first_part.get("upload_id")
if upload_id:
etag_body = [
{"ETag": entry["ETag"], "PartNumber": i + 1}
for i, entry in enumerate(etag_data)
]
self._post_or_put(
method="POST",
sub_url="data_entries/raw_data/staged/upload_urls/complete/",
params={"staging_type": "core"},
body={
"upload_id": upload_id,
"new_filename": first_part["new_filename"],
"etag_data": etag_body,
},
)
main_task = None
file_count = len(file_data)
with _make_progress(self.mute_bars, False) as progress:
if not progress.disable:
main_task = progress.add_task(
"Uploading files…",
total=file_count,
show_percent=False,
show_total=True,
show_spinner=True,
pad="",
)
max_workers = min(8, len(file_data))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(__upload_file, file_info): file_info # type: ignore # noqa: PGH003
for file_info in file_data
}
for future in as_completed(futures):
future.result() # raise early if anything went wrong
if main_task is not None:
progress.update(main_task, advance=1, refresh=True)
def download_videos(
self,
data_ids: str | list[str],
dest_dir: str | Path | None = None,
data_type: Literal["raw", "processed"] = "processed",
):
"""
Download raw or processed RHEED videos to disk.

Args:
data_ids (str | list[str]): One or more data IDs from the data catalogue.
dest_dir (str | Path | None): Directory to write the files to.
Defaults to the current working directory.
data_type (Literal["raw", "processed"]): Whether to download raw or processed data.
"""
chunk_size: int = 20 * 1024 * 1024 # 20 MiB read chunks
# Normalise inputs
if isinstance(data_ids, str):
data_ids = [data_ids]
if dest_dir is None:
dest_dir = Path.cwd()
else:
dest_dir = Path(dest_dir).expanduser().resolve()
dest_dir.mkdir(parents=True, exist_ok=True)
def __download_one(data_id: str) -> None:
# 1) Resolve the presigned URL -------------------------------------
url_type = "raw_data" if data_type == "raw" else "processed_data"
meta: dict = self._get( # type: ignore # noqa: PGH003
sub_url=f"data_entries/{url_type}/{data_id}",
params={"return_as": "url-download"},
)
if meta is None:
raise ClientError(f"No processed data found for data_id '{data_id}'")
url = meta["url"]
file_name = (
meta.get("file_name") or f"{data_id}.{meta.get('file_format', 'mp4')}"
)
target = dest_dir / file_name # type: ignore # noqa: PGH003
# 2) Open the stream *once* (HEAD not allowed)
with self._session.get( # type: ignore # noqa: PGH003
url, stream=True, allow_redirects=True, timeout=30
) as resp:
resp.raise_for_status()
# Attempt to read the size from **this** GET response
total_size = int(resp.headers.get("Content-Length", 0))
# 3) Create a nested bar for this file
if total_size: # we know the size → percent bar
bar_id = progress.add_task(
f"[red]{file_name}",
total=total_size,
show_percent=True,
show_total=False,
show_spinner=False,
pad="",
)
else: # unknown size → indeterminate spinner
bar_id = progress.add_task(
f"[red]{file_name}",
total=None,
show_percent=False,
show_total=False,
show_spinner=True,
pad="",
)
# 4) Stream the bytes to disk with updates
with target.open("wb") as fh:
for chunk in resp.iter_content(chunk_size):
if chunk: # filter out keep-alive
fh.write(chunk)
progress.update(bar_id, advance=len(chunk))
# Download files
with _make_progress(self.mute_bars, False) as progress:
# master bar
master_task = None
if not progress.disable:
master_task = progress.add_task(
"Downloading videos…",
total=len(data_ids),
show_percent=False,
show_total=True,
show_spinner=True,
pad="",
)
# thread-pool for concurrent downloads
max_workers = min(8, len(data_ids))
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = {pool.submit(__download_one, did): did for did in data_ids}
for fut in as_completed(futures):
# propagate any exceptions early
fut.result()
if master_task is not None:
progress.update(master_task, advance=1, refresh=True)