Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions src/aignostics/platform/_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Authenticated API wrapper and configuration.

This module defines the thin API subclass and configuration that lift
``token_provider`` to a first-class attribute. Kept separate from ``_client``
so that resource modules can import these types directly without circular
dependencies.
"""

from collections.abc import Callable

from aignx.codegen.api.public_api import PublicApi
from aignx.codegen.api_client import ApiClient
from aignx.codegen.configuration import AuthSettings, Configuration
from loguru import logger


class _OAuth2TokenProviderConfiguration(Configuration):
"""Overwrites the original Configuration to call a function to obtain a bearer token.

The base class does not support callbacks. This is necessary for integrations where
access tokens may expire or need to be refreshed or rotated automatically.
"""

def __init__(
self, host: str, ssl_ca_cert: str | None = None, token_provider: Callable[[], str] | None = None
) -> None:
super().__init__(host=host, ssl_ca_cert=ssl_ca_cert)
self.token_provider = token_provider

def auth_settings(self) -> AuthSettings:
token = self.token_provider() if self.token_provider else None
if not token:
if self.token_provider is not None:
logger.warning(
"token_provider returned an empty or None token; "
"request will proceed without an Authorization header"
)
return {}
return {
"OAuth2AuthorizationCodeBearer": {
"type": "oauth2",
"in": "header",
"key": "Authorization",
"value": f"Bearer {token}",
Comment on lines +39 to +44
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The external token_provider return value is always wrapped as Authorization: Bearer {token} in _OAuth2TokenProviderConfiguration.auth_settings(). If callers pass a provider that returns a fully-formed header value (e.g., already prefixed with Bearer ), requests will end up with Bearer Bearer ... and fail authentication. Consider normalizing (strip a leading Bearer prefix before formatting, or accept already-prefixed values) and/or clearly documenting that the provider must return the raw access token without the scheme prefix.

Suggested change
return {
"OAuth2AuthorizationCodeBearer": {
"type": "oauth2",
"in": "header",
"key": "Authorization",
"value": f"Bearer {token}",
# Normalize token to avoid double 'Bearer ' prefixes if the provider
# already returns a value starting with 'Bearer '.
token_str = str(token).strip()
bearer_value: str
if token_str.lower().startswith("bearer "):
bearer_value = token_str
else:
bearer_value = f"Bearer {token_str}"
return {
"OAuth2AuthorizationCodeBearer": {
"type": "oauth2",
"in": "header",
"key": "Authorization",
"value": bearer_value,

Copilot uses AI. Check for mistakes.
}
}


class _AuthenticatedApi(PublicApi):
"""Thin wrapper around the generated :class:`PublicApi`.

Lifts ``token_provider`` from the deeply-nested ``Configuration`` to a
top-level attribute, making it accessible without traversing codegen internals.
"""

token_provider: Callable[[], str] | None

def __init__(self, api_client: ApiClient, token_provider: Callable[[], str] | None = None) -> None:
super().__init__(api_client)
self.token_provider = token_provider


class _AuthenticatedResource:
"""Base for platform resource classes that require an authenticated API client.

Validates at construction time that the provided API object is a genuine
:class:`_AuthenticatedApi` instance, ensuring ``token_provider`` is available
for per-user cache key isolation in ``@cached_operation``.
"""

_api: _AuthenticatedApi

def __init__(self, api: _AuthenticatedApi) -> None:
"""Initialize with an authenticated API client.

Args:
api: The configured API client providing ``token_provider``.

Raises:
TypeError: If *api* is not an :class:`_AuthenticatedApi` instance.
"""
if not isinstance(api, _AuthenticatedApi): # runtime guard for untyped callers
msg = ( # type: ignore[unreachable]
f"{type(self).__name__} requires _AuthenticatedApi, "
f"got {type(api).__name__!r}. "
"Use Client to obtain a correctly configured instance."
)
raise TypeError(msg)
self._api = api
118 changes: 60 additions & 58 deletions src/aignostics/platform/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from urllib.request import getproxies

import semver
from aignx.codegen.api.public_api import PublicApi
from aignx.codegen.api_client import ApiClient
from aignx.codegen.configuration import AuthSettings, Configuration
from aignx.codegen.exceptions import NotFoundException, ServiceException
from aignx.codegen.models import ApplicationReadResponse as Application
from aignx.codegen.models import MeReadResponse as Me
Expand All @@ -22,6 +20,7 @@
from urllib3.exceptions import IncompleteRead, PoolError, ProtocolError, ProxyError
from urllib3.exceptions import TimeoutError as Urllib3TimeoutError

from aignostics.platform._api import _AuthenticatedApi, _OAuth2TokenProviderConfiguration
from aignostics.platform._authentication import get_token
from aignostics.platform._operation_cache import cached_operation
from aignostics.platform.resources.applications import Applications, Versions
Expand All @@ -30,6 +29,10 @@

from ._settings import settings

# Safety bound for the external token-provider cache. In normal usage callers
# reuse a single provider reference, so this limit should never be reached.
_MAX_EXTERNAL_CLIENTS = 16

RETRYABLE_EXCEPTIONS = (
ServiceException,
Urllib3TimeoutError,
Expand Down Expand Up @@ -59,62 +62,42 @@ def _log_retry_attempt(retry_state: RetryCallState) -> None:
)


class _OAuth2TokenProviderConfiguration(Configuration):
"""
Overwrites the original Configuration to call a function to obtain a refresh token.

The base class does not support callbacks. This is necessary for integrations where
tokens may expire or need to be refreshed automatically.
"""

def __init__(
self, host: str, ssl_ca_cert: str | None = None, token_provider: Callable[[], str] | None = None
) -> None:
super().__init__(host=host, ssl_ca_cert=ssl_ca_cert)
self.token_provider = token_provider

def auth_settings(self) -> AuthSettings:
token = self.token_provider() if self.token_provider else None
if not token:
return {}
return {
"OAuth2AuthorizationCodeBearer": {
"type": "oauth2",
"in": "header",
"key": "Authorization",
"value": f"Bearer {token}",
}
}


class Client:
"""Main client for interacting with the Aignostics Platform API.

- Provides access to platform resources like applications, versions, and runs.
- Handles authentication and API client configuration.
- Supports external token providers for machine-to-machine or custom auth flows.
- Retries on network and server errors for specific operations.
- Caches operation results for specific operations.
"""

_api_client_cached: ClassVar[PublicApi | None] = None
_api_client_uncached: ClassVar[PublicApi | None] = None
_api_client_cached: ClassVar[_AuthenticatedApi | None] = None
_api_client_uncached: ClassVar[_AuthenticatedApi | None] = None
_api_client_external: ClassVar[dict[Callable[[], str], _AuthenticatedApi]] = {}

_api: _AuthenticatedApi
applications: Applications
versions: Versions
runs: Runs

def __init__(self, cache_token: bool = True) -> None:
def __init__(self, cache_token: bool = True, token_provider: Callable[[], str] | None = None) -> None:
"""Initializes a client instance with authenticated API access.

Args:
cache_token (bool): If True, caches the authentication token.
Defaults to True.
cache_token: If True, caches the authentication token. Defaults to True.
Ignored when ``token_provider`` is supplied.
token_provider: Optional external token provider callable. When provided,
bypasses internal OAuth authentication entirely. The callable must
return a raw access token string (without the ``Bearer `` prefix).
When set, ``cache_token`` has no effect because the external provider
manages its own token lifecycle.

Sets up resource accessors for applications, versions, and runs.
"""
try:
logger.trace("Initializing client with cache_token={}", cache_token)
self._api = Client.get_api_client(cache_token=cache_token)
logger.trace("Initializing client with cache_token={}, token_provider={}", cache_token, token_provider)
self._api = Client.get_api_client(cache_token=cache_token, token_provider=token_provider)
self.applications: Applications = Applications(self._api)
self.runs: Runs = Runs(self._api)
self.versions: Versions = Versions(self._api)
Expand Down Expand Up @@ -143,7 +126,7 @@ def me(self, nocache: bool = False) -> Me:
aignx.codegen.exceptions.ApiException: If the API call fails.
"""

@cached_operation(ttl=settings().me_cache_ttl, use_token=True)
@cached_operation(ttl=settings().me_cache_ttl, token_provider=self._api.token_provider)
def me_with_retry() -> Me:
return Retrying( # We are not using Tenacity annotations as settings can change at runtime
retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS),
Expand Down Expand Up @@ -177,7 +160,7 @@ def application(self, application_id: str, nocache: bool = False) -> Application
Application: The application object.
"""

@cached_operation(ttl=settings().application_cache_ttl, use_token=True)
@cached_operation(ttl=settings().application_cache_ttl, token_provider=self._api.token_provider)
def application_with_retry(application_id: str) -> Application:
return Retrying(
retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS),
Expand Down Expand Up @@ -234,7 +217,7 @@ def application_version(
raise ValueError(message)

# Make the API call with retry logic and caching
@cached_operation(ttl=settings().application_version_cache_ttl, use_token=True)
@cached_operation(ttl=settings().application_version_cache_ttl, token_provider=self._api.token_provider)
def application_version_with_retry(application_id: str, version: str) -> ApplicationVersion:
return Retrying(
retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS),
Expand Down Expand Up @@ -268,44 +251,63 @@ def run(self, run_id: str) -> Run:
return Run(self._api, run_id)

@staticmethod
def get_api_client(cache_token: bool = True) -> PublicApi:
def get_api_client(cache_token: bool = True, token_provider: Callable[[], str] | None = None) -> _AuthenticatedApi:
"""Create and configure an authenticated API client.

API client instances are shared across all Client instances for efficient connection reuse.
Two separate instances are maintained: one for cached tokens and one for uncached tokens.
Three pools are maintained: cached-token, uncached-token, and external-provider (keyed by
the provider callable — callers should reuse a stable ``token_provider`` reference for
connection reuse).

Args:
cache_token (bool): If True, caches the authentication token.
Defaults to True.
cache_token: If True, caches the authentication token. Defaults to True.
token_provider: Optional external token provider. When provided, bypasses
internal OAuth and uses this callable to obtain bearer tokens.

Returns:
PublicApi: Configured API client with authentication token.
_AuthenticatedApi: Configured API client with authentication token.

Raises:
RuntimeError: If authentication fails.
"""
# Return cached instance if available
if cache_token and Client._api_client_cached is not None:
# Check singleton caches first
if token_provider is not None:
if token_provider in Client._api_client_external:
return Client._api_client_external[token_provider]
elif cache_token and Client._api_client_cached is not None:
return Client._api_client_cached
if not cache_token and Client._api_client_uncached is not None:
elif not cache_token and Client._api_client_uncached is not None:
return Client._api_client_uncached

def token_provider() -> str:
return get_token(use_cache=cache_token)
# Resolve the effective token provider
effective_provider: Callable[[], str] = (
token_provider if token_provider is not None else (lambda: get_token(use_cache=cache_token))
)

# Build the API client
ca_file = os.getenv("REQUESTS_CA_BUNDLE") # point to .cer file of proxy if defined
config = _OAuth2TokenProviderConfiguration(
host=settings().api_root, ssl_ca_cert=ca_file, token_provider=token_provider
host=settings().api_root, ssl_ca_cert=ca_file, token_provider=effective_provider
)
config.proxy = getproxies().get("https") # use system proxy
client = ApiClient(
config,
)
client = ApiClient(config)
client.user_agent = user_agent()
api_client = PublicApi(client)

# Cache the instance
if cache_token:
api_client = _AuthenticatedApi(client, effective_provider)

# Store in the appropriate singleton cache.
# For external providers we use a simple bounded dict rather than LRU:
# switching providers is rare in practice, and a full clear is simpler
# than tracking access order while still bounding memory.
if token_provider is not None:
if len(Client._api_client_external) >= _MAX_EXTERNAL_CLIENTS:
logger.warning(
"External token provider cache exceeded {} entries; clearing to prevent resource leak. "
"Reuse a stable token_provider reference for optimal connection reuse.",
_MAX_EXTERNAL_CLIENTS,
)
Client._api_client_external.clear()
Client._api_client_external[token_provider] = api_client
elif cache_token:
Client._api_client_cached = api_client
else:
Client._api_client_uncached = api_client
Expand Down
28 changes: 18 additions & 10 deletions src/aignostics/platform/_operation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
- Supports selective cache clearing by function
"""

from __future__ import annotations

import functools
import hashlib
import time
import typing as t
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar

from ._authentication import get_token
if TYPE_CHECKING:
from collections.abc import Callable

# Cache storage for operation results
_operation_cache: dict[str, tuple[Any, float]] = {}
Expand Down Expand Up @@ -92,16 +95,19 @@ def cache_key_with_token(token: str, func_qualified_name: str, *args: object, **


def cached_operation(
ttl: int, *, use_token: bool = True, instance_attrs: tuple[str, ...] | None = None
ttl: int,
*,
token_provider: Callable[[], str] | None = None,
instance_attrs: tuple[str, ...] | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Caches the result of a function call for a specified time-to-live (TTL).

Args:
ttl (int): Time-to-live for the cache in seconds.
use_token (bool): If True, includes the authentication token in the cache key.
This is useful for Client methods that should cache per-user.
When use_token is True and no instance_attrs are specified, the 'self'
argument is excluded from the cache key to enable cache sharing across instances.
token_provider (Callable[[], str] | None): A callable returning the current
authentication token string. When provided, the token is included in the
cache key for per-user isolation. Pass ``None`` to omit the token from
the cache key.
instance_attrs (tuple[str, ...] | None): Instance attributes to include in the cache key.
This is useful for instance methods where caching should be per-instance based on
specific attributes (e.g., 'run_id' for Run.details()).
Expand All @@ -116,6 +122,7 @@ def cached_operation(
"""

def decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# Check if nocache is requested and remove it from kwargs before passing to func
nocache = kwargs.pop("nocache", False)
Expand All @@ -132,8 +139,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
instance_values = tuple(getattr(instance, attr) for attr in instance_attrs)
cache_args = instance_values + args[1:]

if use_token:
key = cache_key_with_token(get_token(True), func_qualified_name, *cache_args, **kwargs)
if token_provider is not None:
token = token_provider()
key = cache_key_with_token(token, func_qualified_name, *cache_args, **kwargs)
else:
key = cache_key(func_qualified_name, *cache_args, **kwargs)

Expand Down
Loading
Loading