diff --git a/src/aignostics/platform/_api.py b/src/aignostics/platform/_api.py new file mode 100644 index 000000000..d90dad243 --- /dev/null +++ b/src/aignostics/platform/_api.py @@ -0,0 +1,89 @@ +"""Authenticated API wrapper and configuration. + +This module defines the thin API subclass and configuration that lift +``token_provider`` to a first-class attribute. Kept separate from ``_client`` +so that resource modules can import these types directly without circular +dependencies. +""" + +from collections.abc import Callable + +from aignx.codegen.api.public_api import PublicApi +from aignx.codegen.api_client import ApiClient +from aignx.codegen.configuration import AuthSettings, Configuration +from loguru import logger + + +class _OAuth2TokenProviderConfiguration(Configuration): + """Overwrites the original Configuration to call a function to obtain a bearer token. + + The base class does not support callbacks. This is necessary for integrations where + access tokens may expire or need to be refreshed or rotated automatically. + """ + + def __init__( + self, host: str, ssl_ca_cert: str | None = None, token_provider: Callable[[], str] | None = None + ) -> None: + super().__init__(host=host, ssl_ca_cert=ssl_ca_cert) + self.token_provider = token_provider + + def auth_settings(self) -> AuthSettings: + token = self.token_provider() if self.token_provider else None + if not token: + if self.token_provider is not None: + logger.warning( + "token_provider returned an empty or None token; " + "request will proceed without an Authorization header" + ) + return {} + return { + "OAuth2AuthorizationCodeBearer": { + "type": "oauth2", + "in": "header", + "key": "Authorization", + "value": f"Bearer {token}", + } + } + + +class _AuthenticatedApi(PublicApi): + """Thin wrapper around the generated :class:`PublicApi`. + + Lifts ``token_provider`` from the deeply-nested ``Configuration`` to a + top-level attribute, making it accessible without traversing codegen internals. + """ + + token_provider: Callable[[], str] | None + + def __init__(self, api_client: ApiClient, token_provider: Callable[[], str] | None = None) -> None: + super().__init__(api_client) + self.token_provider = token_provider + + +class _AuthenticatedResource: + """Base for platform resource classes that require an authenticated API client. + + Validates at construction time that the provided API object is a genuine + :class:`_AuthenticatedApi` instance, ensuring ``token_provider`` is available + for per-user cache key isolation in ``@cached_operation``. + """ + + _api: _AuthenticatedApi + + def __init__(self, api: _AuthenticatedApi) -> None: + """Initialize with an authenticated API client. + + Args: + api: The configured API client providing ``token_provider``. + + Raises: + TypeError: If *api* is not an :class:`_AuthenticatedApi` instance. + """ + if not isinstance(api, _AuthenticatedApi): # runtime guard for untyped callers + msg = ( # type: ignore[unreachable] + f"{type(self).__name__} requires _AuthenticatedApi, " + f"got {type(api).__name__!r}. " + "Use Client to obtain a correctly configured instance." + ) + raise TypeError(msg) + self._api = api diff --git a/src/aignostics/platform/_client.py b/src/aignostics/platform/_client.py index 02f8dd014..fb52b7d57 100644 --- a/src/aignostics/platform/_client.py +++ b/src/aignostics/platform/_client.py @@ -4,9 +4,7 @@ from urllib.request import getproxies import semver -from aignx.codegen.api.public_api import PublicApi from aignx.codegen.api_client import ApiClient -from aignx.codegen.configuration import AuthSettings, Configuration from aignx.codegen.exceptions import NotFoundException, ServiceException from aignx.codegen.models import ApplicationReadResponse as Application from aignx.codegen.models import MeReadResponse as Me @@ -22,6 +20,7 @@ from urllib3.exceptions import IncompleteRead, PoolError, ProtocolError, ProxyError from urllib3.exceptions import TimeoutError as Urllib3TimeoutError +from aignostics.platform._api import _AuthenticatedApi, _OAuth2TokenProviderConfiguration from aignostics.platform._authentication import get_token from aignostics.platform._operation_cache import cached_operation from aignostics.platform.resources.applications import Applications, Versions @@ -30,6 +29,10 @@ from ._settings import settings +# Safety bound for the external token-provider cache. In normal usage callers +# reuse a single provider reference, so this limit should never be reached. +_MAX_EXTERNAL_CLIENTS = 16 + RETRYABLE_EXCEPTIONS = ( ServiceException, Urllib3TimeoutError, @@ -59,62 +62,42 @@ def _log_retry_attempt(retry_state: RetryCallState) -> None: ) -class _OAuth2TokenProviderConfiguration(Configuration): - """ - Overwrites the original Configuration to call a function to obtain a refresh token. - - The base class does not support callbacks. This is necessary for integrations where - tokens may expire or need to be refreshed automatically. - """ - - def __init__( - self, host: str, ssl_ca_cert: str | None = None, token_provider: Callable[[], str] | None = None - ) -> None: - super().__init__(host=host, ssl_ca_cert=ssl_ca_cert) - self.token_provider = token_provider - - def auth_settings(self) -> AuthSettings: - token = self.token_provider() if self.token_provider else None - if not token: - return {} - return { - "OAuth2AuthorizationCodeBearer": { - "type": "oauth2", - "in": "header", - "key": "Authorization", - "value": f"Bearer {token}", - } - } - - class Client: """Main client for interacting with the Aignostics Platform API. - Provides access to platform resources like applications, versions, and runs. - Handles authentication and API client configuration. + - Supports external token providers for machine-to-machine or custom auth flows. - Retries on network and server errors for specific operations. - Caches operation results for specific operations. """ - _api_client_cached: ClassVar[PublicApi | None] = None - _api_client_uncached: ClassVar[PublicApi | None] = None + _api_client_cached: ClassVar[_AuthenticatedApi | None] = None + _api_client_uncached: ClassVar[_AuthenticatedApi | None] = None + _api_client_external: ClassVar[dict[Callable[[], str], _AuthenticatedApi]] = {} + _api: _AuthenticatedApi applications: Applications versions: Versions runs: Runs - def __init__(self, cache_token: bool = True) -> None: + def __init__(self, cache_token: bool = True, token_provider: Callable[[], str] | None = None) -> None: """Initializes a client instance with authenticated API access. Args: - cache_token (bool): If True, caches the authentication token. - Defaults to True. + cache_token: If True, caches the authentication token. Defaults to True. + Ignored when ``token_provider`` is supplied. + token_provider: Optional external token provider callable. When provided, + bypasses internal OAuth authentication entirely. The callable must + return a raw access token string (without the ``Bearer `` prefix). + When set, ``cache_token`` has no effect because the external provider + manages its own token lifecycle. Sets up resource accessors for applications, versions, and runs. """ try: - logger.trace("Initializing client with cache_token={}", cache_token) - self._api = Client.get_api_client(cache_token=cache_token) + logger.trace("Initializing client with cache_token={}, token_provider={}", cache_token, token_provider) + self._api = Client.get_api_client(cache_token=cache_token, token_provider=token_provider) self.applications: Applications = Applications(self._api) self.runs: Runs = Runs(self._api) self.versions: Versions = Versions(self._api) @@ -143,7 +126,7 @@ def me(self, nocache: bool = False) -> Me: aignx.codegen.exceptions.ApiException: If the API call fails. """ - @cached_operation(ttl=settings().me_cache_ttl, use_token=True) + @cached_operation(ttl=settings().me_cache_ttl, token_provider=self._api.token_provider) def me_with_retry() -> Me: return Retrying( # We are not using Tenacity annotations as settings can change at runtime retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS), @@ -177,7 +160,7 @@ def application(self, application_id: str, nocache: bool = False) -> Application Application: The application object. """ - @cached_operation(ttl=settings().application_cache_ttl, use_token=True) + @cached_operation(ttl=settings().application_cache_ttl, token_provider=self._api.token_provider) def application_with_retry(application_id: str) -> Application: return Retrying( retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS), @@ -234,7 +217,7 @@ def application_version( raise ValueError(message) # Make the API call with retry logic and caching - @cached_operation(ttl=settings().application_version_cache_ttl, use_token=True) + @cached_operation(ttl=settings().application_version_cache_ttl, token_provider=self._api.token_provider) def application_version_with_retry(application_id: str, version: str) -> ApplicationVersion: return Retrying( retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS), @@ -268,44 +251,63 @@ def run(self, run_id: str) -> Run: return Run(self._api, run_id) @staticmethod - def get_api_client(cache_token: bool = True) -> PublicApi: + def get_api_client(cache_token: bool = True, token_provider: Callable[[], str] | None = None) -> _AuthenticatedApi: """Create and configure an authenticated API client. API client instances are shared across all Client instances for efficient connection reuse. - Two separate instances are maintained: one for cached tokens and one for uncached tokens. + Three pools are maintained: cached-token, uncached-token, and external-provider (keyed by + the provider callable — callers should reuse a stable ``token_provider`` reference for + connection reuse). Args: - cache_token (bool): If True, caches the authentication token. - Defaults to True. + cache_token: If True, caches the authentication token. Defaults to True. + token_provider: Optional external token provider. When provided, bypasses + internal OAuth and uses this callable to obtain bearer tokens. Returns: - PublicApi: Configured API client with authentication token. + _AuthenticatedApi: Configured API client with authentication token. Raises: RuntimeError: If authentication fails. """ - # Return cached instance if available - if cache_token and Client._api_client_cached is not None: + # Check singleton caches first + if token_provider is not None: + if token_provider in Client._api_client_external: + return Client._api_client_external[token_provider] + elif cache_token and Client._api_client_cached is not None: return Client._api_client_cached - if not cache_token and Client._api_client_uncached is not None: + elif not cache_token and Client._api_client_uncached is not None: return Client._api_client_uncached - def token_provider() -> str: - return get_token(use_cache=cache_token) + # Resolve the effective token provider + effective_provider: Callable[[], str] = ( + token_provider if token_provider is not None else (lambda: get_token(use_cache=cache_token)) + ) + # Build the API client ca_file = os.getenv("REQUESTS_CA_BUNDLE") # point to .cer file of proxy if defined config = _OAuth2TokenProviderConfiguration( - host=settings().api_root, ssl_ca_cert=ca_file, token_provider=token_provider + host=settings().api_root, ssl_ca_cert=ca_file, token_provider=effective_provider ) config.proxy = getproxies().get("https") # use system proxy - client = ApiClient( - config, - ) + client = ApiClient(config) client.user_agent = user_agent() - api_client = PublicApi(client) - - # Cache the instance - if cache_token: + api_client = _AuthenticatedApi(client, effective_provider) + + # Store in the appropriate singleton cache. + # For external providers we use a simple bounded dict rather than LRU: + # switching providers is rare in practice, and a full clear is simpler + # than tracking access order while still bounding memory. + if token_provider is not None: + if len(Client._api_client_external) >= _MAX_EXTERNAL_CLIENTS: + logger.warning( + "External token provider cache exceeded {} entries; clearing to prevent resource leak. " + "Reuse a stable token_provider reference for optimal connection reuse.", + _MAX_EXTERNAL_CLIENTS, + ) + Client._api_client_external.clear() + Client._api_client_external[token_provider] = api_client + elif cache_token: Client._api_client_cached = api_client else: Client._api_client_uncached = api_client diff --git a/src/aignostics/platform/_operation_cache.py b/src/aignostics/platform/_operation_cache.py index 90b9cc8d7..e855f22bf 100644 --- a/src/aignostics/platform/_operation_cache.py +++ b/src/aignostics/platform/_operation_cache.py @@ -11,13 +11,16 @@ - Supports selective cache clearing by function """ +from __future__ import annotations + +import functools import hashlib import time import typing as t -from collections.abc import Callable -from typing import Any, ParamSpec, TypeVar +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar -from ._authentication import get_token +if TYPE_CHECKING: + from collections.abc import Callable # Cache storage for operation results _operation_cache: dict[str, tuple[Any, float]] = {} @@ -92,16 +95,19 @@ def cache_key_with_token(token: str, func_qualified_name: str, *args: object, ** def cached_operation( - ttl: int, *, use_token: bool = True, instance_attrs: tuple[str, ...] | None = None + ttl: int, + *, + token_provider: Callable[[], str] | None = None, + instance_attrs: tuple[str, ...] | None = None, ) -> Callable[[Callable[P, T]], Callable[P, T]]: """Caches the result of a function call for a specified time-to-live (TTL). Args: ttl (int): Time-to-live for the cache in seconds. - use_token (bool): If True, includes the authentication token in the cache key. - This is useful for Client methods that should cache per-user. - When use_token is True and no instance_attrs are specified, the 'self' - argument is excluded from the cache key to enable cache sharing across instances. + token_provider (Callable[[], str] | None): A callable returning the current + authentication token string. When provided, the token is included in the + cache key for per-user isolation. Pass ``None`` to omit the token from + the cache key. instance_attrs (tuple[str, ...] | None): Instance attributes to include in the cache key. This is useful for instance methods where caching should be per-instance based on specific attributes (e.g., 'run_id' for Run.details()). @@ -116,6 +122,7 @@ def cached_operation( """ def decorator(func: Callable[P, T]) -> Callable[P, T]: + @functools.wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # Check if nocache is requested and remove it from kwargs before passing to func nocache = kwargs.pop("nocache", False) @@ -132,8 +139,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: instance_values = tuple(getattr(instance, attr) for attr in instance_attrs) cache_args = instance_values + args[1:] - if use_token: - key = cache_key_with_token(get_token(True), func_qualified_name, *cache_args, **kwargs) + if token_provider is not None: + token = token_provider() + key = cache_key_with_token(token, func_qualified_name, *cache_args, **kwargs) else: key = cache_key(func_qualified_name, *cache_args, **kwargs) diff --git a/src/aignostics/platform/resources/applications.py b/src/aignostics/platform/resources/applications.py index f507ddf41..ab4075899 100644 --- a/src/aignostics/platform/resources/applications.py +++ b/src/aignostics/platform/resources/applications.py @@ -9,7 +9,6 @@ from operator import itemgetter import semver -from aignx.codegen.api.public_api import PublicApi from aignx.codegen.exceptions import NotFoundException, ServiceException from aignx.codegen.models import ApplicationReadResponse as Application from aignx.codegen.models import ApplicationReadShortResponse as ApplicationSummary @@ -26,6 +25,7 @@ from urllib3.exceptions import IncompleteRead, PoolError, ProtocolError, ProxyError from urllib3.exceptions import TimeoutError as Urllib3TimeoutError +from aignostics.platform._api import _AuthenticatedApi, _AuthenticatedResource from aignostics.platform._operation_cache import cached_operation from aignostics.platform._settings import settings from aignostics.platform.resources.utils import paginate @@ -60,20 +60,12 @@ def _log_retry_attempt(retry_state: RetryCallState) -> None: ) -class Versions: +class Versions(_AuthenticatedResource): """Resource class for managing application versions. Provides operations to list and retrieve application versions. """ - def __init__(self, api: PublicApi) -> None: - """Initializes the Versions resource with the API platform. - - Args: - api (PublicApi): The configured API platform. - """ - self._api = api - def list(self, application: Application | str, nocache: bool = False) -> builtins.list[VersionTuple]: """Find all versions for a specific application. @@ -92,7 +84,7 @@ def list(self, application: Application | str, nocache: bool = False) -> builtin """ application_id = application.application_id if isinstance(application, Application) else application - @cached_operation(ttl=settings().application_cache_ttl, use_token=True) + @cached_operation(ttl=settings().application_cache_ttl, token_provider=self._api.token_provider) def list_with_retry(app_id: str) -> Application: return Retrying( retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS), @@ -149,7 +141,7 @@ def details( raise ValueError(message) # Make the API call with retry logic and caching - @cached_operation(ttl=settings().application_version_cache_ttl, use_token=True) + @cached_operation(ttl=settings().application_version_cache_ttl, token_provider=self._api.token_provider) def details_with_retry(app_id: str, app_version: str) -> ApplicationVersion: return Retrying( retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS), @@ -224,19 +216,19 @@ def latest(self, application: Application | str, nocache: bool = False) -> Versi return sorted_versions[0] if sorted_versions else None -class Applications: +class Applications(_AuthenticatedResource): """Resource class for managing applications. Provides operations to list applications and access version resources. """ - def __init__(self, api: PublicApi) -> None: + def __init__(self, api: _AuthenticatedApi) -> None: """Initializes the Applications resource with the API platform. Args: - api (PublicApi): The configured API platform. + api (_AuthenticatedApi): The configured API platform. """ - self._api = api + super().__init__(api) self.versions: Versions = Versions(self._api) def details(self, application_id: str, nocache: bool = False) -> Application: @@ -257,7 +249,7 @@ def details(self, application_id: str, nocache: bool = False) -> Application: aignx.codegen.exceptions.ApiException: If the API call fails. """ - @cached_operation(ttl=settings().application_cache_ttl, use_token=True) + @cached_operation(ttl=settings().application_cache_ttl, token_provider=self._api.token_provider) def details_with_retry(application_id: str) -> Application: return Retrying( retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS), @@ -293,7 +285,7 @@ def list(self, nocache: bool = False) -> t.Iterator[ApplicationSummary]: # Create a wrapper function that applies retry logic and caching to each API call # Caching at this level ensures having a fresh iterator on cache hits - @cached_operation(ttl=settings().application_cache_ttl, use_token=True) + @cached_operation(ttl=settings().application_cache_ttl, token_provider=self._api.token_provider) def list_with_retry(**kwargs: object) -> builtins.list[ApplicationSummary]: return Retrying( retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS), diff --git a/src/aignostics/platform/resources/runs.py b/src/aignostics/platform/resources/runs.py index f2b7a49b9..d81416cd4 100644 --- a/src/aignostics/platform/resources/runs.py +++ b/src/aignostics/platform/resources/runs.py @@ -12,7 +12,6 @@ from time import sleep from typing import Any, cast -from aignx.codegen.api.public_api import PublicApi from aignx.codegen.exceptions import NotFoundException, ServiceException from aignx.codegen.models import ( CustomMetadataUpdateRequest, @@ -49,6 +48,7 @@ from urllib3.exceptions import IncompleteRead, PoolError, ProtocolError, ProxyError from urllib3.exceptions import TimeoutError as Urllib3TimeoutError +from aignostics.platform._api import _AuthenticatedApi, _AuthenticatedResource from aignostics.platform._operation_cache import cached_operation, operation_cache_clear from aignostics.platform._sdk_metadata import ( build_item_sdk_metadata, @@ -105,20 +105,20 @@ class DownloadTimeoutError(RuntimeError): """Exception raised when the download operation exceeds its timeout.""" -class Run: +class Run(_AuthenticatedResource): """Represents a single application run. Provides operations to check status, retrieve results, and download artifacts. """ - def __init__(self, api: PublicApi, run_id: str) -> None: - """Initializes an Run instance. + def __init__(self, api: _AuthenticatedApi, run_id: str) -> None: + """Initializes a Run instance. Args: - api (PublicApi): The configured API client. + api (_AuthenticatedApi): The configured API client. run_id (str): The ID of the application run. """ - self._api = api + super().__init__(api) self.run_id = run_id @classmethod @@ -156,7 +156,7 @@ def details(self, nocache: bool = False, hide_platform_queue_position: bool = Fa Exception: If the API request fails. """ - @cached_operation(ttl=settings().run_cache_ttl, use_token=True) + @cached_operation(ttl=settings().run_cache_ttl, token_provider=self._api.token_provider) def details_with_retry(run_id: str) -> RunData: def _fetch() -> RunData: return Retrying( @@ -243,7 +243,7 @@ def results( # Create a wrapper function that applies retry logic and caching to each API call # Caching at this level ensures having a fresh iterator on cache hits - @cached_operation(ttl=settings().run_cache_ttl, use_token=True) + @cached_operation(ttl=settings().run_cache_ttl, token_provider=self._api.token_provider) def results_with_retry(run_id: str, **kwargs: object) -> list[ItemResultData]: return Retrying( retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS), @@ -498,20 +498,12 @@ def __str__(self) -> str: ) -class Runs: +class Runs(_AuthenticatedResource): """Resource class for managing application runs. Provides operations to submit, find, and retrieve runs. """ - def __init__(self, api: PublicApi) -> None: - """Initializes the Runs resource with the API client. - - Args: - api (PublicApi): The configured API client. - """ - self._api = api - def __call__(self, run_id: str) -> Run: """Retrieves an Run instance for an existing run. @@ -680,7 +672,7 @@ def list_data( # noqa: PLR0913, PLR0917 ) raise ValueError(message) - @cached_operation(ttl=settings().run_cache_ttl, use_token=True) + @cached_operation(ttl=settings().run_cache_ttl, token_provider=self._api.token_provider) def list_data_with_retry(**kwargs: object) -> builtins.list[RunData]: return Retrying( retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS), diff --git a/tests/aignostics/platform/client_cache_test.py b/tests/aignostics/platform/client_cache_test.py index a3f184357..bea8011ba 100644 --- a/tests/aignostics/platform/client_cache_test.py +++ b/tests/aignostics/platform/client_cache_test.py @@ -290,8 +290,8 @@ def test_different_tokens_use_different_cache_entries(mock_settings: MagicMock, mock_me_response_2 = {"user_id": "user-2", "org_id": "org-2"} # Client with token-1 + mock_api_client.token_provider = lambda: "token-1" with ( - patch("aignostics.platform._operation_cache.get_token", return_value="token-1"), patch("aignostics.platform._client.get_token", return_value="token-1"), patch("aignostics.platform._client.Client.get_api_client", return_value=mock_api_client), ): @@ -304,8 +304,8 @@ def test_different_tokens_use_different_cache_entries(mock_settings: MagicMock, assert mock_api_client.get_me_v1_me_get.call_count == 1 # Client with token-2 + mock_api_client.token_provider = lambda: "token-2" with ( - patch("aignostics.platform._operation_cache.get_token", return_value="token-2"), patch("aignostics.platform._client.get_token", return_value="token-2"), patch("aignostics.platform._client.Client.get_api_client", return_value=mock_api_client), ): @@ -330,13 +330,15 @@ def test_token_change_invalidates_cache(mock_settings: MagicMock, mock_api_clien mock_me_response_1 = {"user_id": "user-1", "org_id": "org-1"} mock_me_response_2 = {"user_id": "user-2", "org_id": "org-2"} + # Use a mutable container so the token provider can be changed mid-test + token_holder = ["token-1"] + mock_api_client.token_provider = lambda: token_holder[0] + # First call with token-1 with ( - patch("aignostics.platform._operation_cache.get_token") as mock_get_token, patch("aignostics.platform._client.get_token", return_value="token-1"), patch("aignostics.platform._client.Client.get_api_client", return_value=mock_api_client), ): - mock_get_token.return_value = "token-1" client = Client(cache_token=False) client._api = mock_api_client mock_api_client.get_me_v1_me_get.return_value = mock_me_response_1 @@ -346,7 +348,7 @@ def test_token_change_invalidates_cache(mock_settings: MagicMock, mock_api_clien assert mock_api_client.get_me_v1_me_get.call_count == 1 # Second call with token-2 (simulating token refresh) - mock_get_token.return_value = "token-2" + token_holder[0] = "token-2" mock_api_client.get_me_v1_me_get.return_value = mock_me_response_2 result2 = client.me() @@ -361,13 +363,13 @@ def test_same_token_reuses_cache(mock_settings: MagicMock, mock_api_client: Magi Multiple clients with the same token should share cached values. """ mock_me_response = {"user_id": "test-user", "org_id": "test-org"} + mock_api_client.token_provider = lambda: "token-123" - # First client with token-123 with ( - patch("aignostics.platform._operation_cache.get_token", return_value="token-123"), patch("aignostics.platform._client.get_token", return_value="token-123"), patch("aignostics.platform._client.Client.get_api_client", return_value=mock_api_client), ): + # First client with token-123 client1 = Client(cache_token=False) client1._api = mock_api_client mock_api_client.get_me_v1_me_get.return_value = mock_me_response @@ -376,12 +378,7 @@ def test_same_token_reuses_cache(mock_settings: MagicMock, mock_api_client: Magi assert result1 == mock_me_response assert mock_api_client.get_me_v1_me_get.call_count == 1 - # Second client with same token-123 - with ( - patch("aignostics.platform._operation_cache.get_token", return_value="token-123"), - patch("aignostics.platform._client.get_token", return_value="token-123"), - patch("aignostics.platform._client.Client.get_api_client", return_value=mock_api_client), - ): + # Second client with same token-123 client2 = Client(cache_token=False) client2._api = mock_api_client @@ -544,9 +541,9 @@ def test_cache_is_class_level(mock_settings: MagicMock, mock_api_client: MagicMo The _operation_cache should be a class variable, not an instance variable. """ mock_me_response = {"user_id": "test-user", "org_id": "test-org"} + mock_api_client.token_provider = lambda: "token-123" with ( - patch("aignostics.platform._operation_cache.get_token", return_value="token-123"), patch("aignostics.platform._client.get_token", return_value="token-123"), patch("aignostics.platform._client.Client.get_api_client", return_value=mock_api_client), ): @@ -575,9 +572,9 @@ def test_cache_cleared_affects_all_clients(mock_settings: MagicMock, mock_api_cl Since cache is class-level, clearing it should affect all instances. """ mock_me_response = {"user_id": "test-user", "org_id": "test-org"} + mock_api_client.token_provider = lambda: "token-123" with ( - patch("aignostics.platform._operation_cache.get_token", return_value="token-123"), patch("aignostics.platform._client.get_token", return_value="token-123"), patch("aignostics.platform._client.Client.get_api_client", return_value=mock_api_client), ): @@ -685,9 +682,9 @@ def test_cache_with_very_long_token(mock_settings: MagicMock, mock_api_client: M Cache key should hash long tokens to keep key size manageable. """ long_token = "x" * 10000 # Very long token + mock_api_client.token_provider = lambda: long_token with ( - patch("aignostics.platform._operation_cache.get_token", return_value=long_token), patch("aignostics.platform._client.get_token", return_value=long_token), patch("aignostics.platform._client.Client.get_api_client", return_value=mock_api_client), ): @@ -707,19 +704,19 @@ class TestCacheIntegrationWithAuthentication: @pytest.mark.unit @staticmethod - def test_cache_uses_current_token_from_get_token(mock_settings: MagicMock, mock_api_client: MagicMock) -> None: - """Test that cache always uses the current token from get_token(). + def test_cache_uses_current_token_from_token_provider(mock_settings: MagicMock, mock_api_client: MagicMock) -> None: + """Test that cache always uses the current token from the token provider. - The cache should call get_token() on each operation to get the current token. + The cache should call the token provider on each operation to get the current token. """ mock_me_response = {"user_id": "test-user", "org_id": "test-org"} + token_holder = ["token-1"] + mock_api_client.token_provider = lambda: token_holder[0] with ( - patch("aignostics.platform._operation_cache.get_token") as mock_get_token, patch("aignostics.platform._client.get_token", return_value="token-1"), patch("aignostics.platform._client.Client.get_api_client", return_value=mock_api_client), ): - mock_get_token.return_value = "token-1" mock_api_client.get_me_v1_me_get.return_value = mock_me_response client = Client(cache_token=False) @@ -727,10 +724,9 @@ def test_cache_uses_current_token_from_get_token(mock_settings: MagicMock, mock_ # First call with token-1 client.me() - assert mock_get_token.call_count >= 1 # Change token - mock_get_token.return_value = "token-2" + token_holder[0] = "token-2" mock_me_response_2 = {"user_id": "test-user-2", "org_id": "test-org-2"} mock_api_client.get_me_v1_me_get.return_value = mock_me_response_2 @@ -748,14 +744,13 @@ def test_cache_with_token_refresh_scenario(mock_settings: MagicMock, mock_api_cl """ mock_me_response_1 = {"user_id": "user-1", "org_id": "org-1"} mock_me_response_2 = {"user_id": "user-2", "org_id": "org-2"} + token_holder = ["token-initial"] + mock_api_client.token_provider = lambda: token_holder[0] with ( - patch("aignostics.platform._operation_cache.get_token") as mock_get_token, patch("aignostics.platform._client.get_token", return_value="token-initial"), patch("aignostics.platform._client.Client.get_api_client", return_value=mock_api_client), ): - # Initial token - mock_get_token.return_value = "token-initial" mock_api_client.get_me_v1_me_get.return_value = mock_me_response_1 client = Client(cache_token=False) @@ -772,7 +767,7 @@ def test_cache_with_token_refresh_scenario(mock_settings: MagicMock, mock_api_cl assert mock_api_client.get_me_v1_me_get.call_count == 1 # Token refresh happens - mock_get_token.return_value = "token-refreshed" + token_holder[0] = "token-refreshed" mock_api_client.get_me_v1_me_get.return_value = mock_me_response_2 # Call 3: New token means cache miss, fetches new data diff --git a/tests/aignostics/platform/client_token_provider_test.py b/tests/aignostics/platform/client_token_provider_test.py index e787f2110..eb8c923fa 100644 --- a/tests/aignostics/platform/client_token_provider_test.py +++ b/tests/aignostics/platform/client_token_provider_test.py @@ -1,10 +1,12 @@ """Tests for the token provider configuration and its integration with the client.""" +from collections.abc import Callable from unittest.mock import Mock, patch import pytest -from aignostics.platform._client import Client, _OAuth2TokenProviderConfiguration +from aignostics.platform._api import _AuthenticatedApi, _OAuth2TokenProviderConfiguration +from aignostics.platform._client import Client @pytest.fixture(autouse=True) @@ -12,6 +14,16 @@ def _clear_api_client_cache() -> None: """Clear the API client cache before each test to ensure test isolation.""" Client._api_client_cached = None Client._api_client_uncached = None + Client._api_client_external.clear() + + +def _make_provider(token: str) -> Callable[[], str]: + """Create a token provider that returns the given token string.""" + + def provider() -> str: + return token + + return provider @pytest.mark.unit @@ -38,13 +50,11 @@ def test_client_passes_token_provider() -> None: with ( patch("aignostics.platform._client.get_token", return_value="client-token"), patch("aignostics.platform._client.ApiClient") as api_client_mock, - patch("aignostics.platform._client.PublicApi") as public_api_mock, ): Client(cache_token=False) config_used = api_client_mock.call_args[0][0] assert isinstance(config_used, _OAuth2TokenProviderConfiguration) assert config_used.token_provider() == "client-token" - public_api_mock.assert_called() @pytest.mark.unit @@ -53,12 +63,202 @@ def test_client_me_calls_api() -> None: with ( patch("aignostics.platform._client.get_token", return_value="client-token"), patch("aignostics.platform._client.ApiClient"), - patch("aignostics.platform._client.PublicApi") as public_api_mock, + patch.object(_AuthenticatedApi, "__init__", lambda self, *a, **kw: None) as _, ): + client = Client() + # Manually set up the mock api on the client api_instance = Mock() api_instance.get_me_v1_me_get.return_value = "me-info" - public_api_mock.return_value = api_instance - client = Client() + api_instance.token_provider = lambda: "client-token" + client._api = api_instance result = client.me() assert result == "me-info" api_instance.get_me_v1_me_get.assert_called_once() + + +# --- External token provider tests --- + + +@pytest.mark.unit +def test_client_with_external_token_provider() -> None: + """Test that Client accepts an external token provider and initializes successfully.""" + my_provider = _make_provider("my-m2m-token") + + with ( + patch("aignostics.platform._client.ApiClient") as api_client_mock, + patch.object(_AuthenticatedApi, "__init__", lambda self, *a, **kw: None), + ): + Client(token_provider=my_provider) + + # Verify the config received the external provider + config_used = api_client_mock.call_args[0][0] + assert isinstance(config_used, _OAuth2TokenProviderConfiguration) + assert config_used.token_provider is my_provider + + +@pytest.mark.unit +def test_external_provider_bypasses_oauth() -> None: + """Test that get_token is NOT called when an external token provider is used.""" + my_provider = _make_provider("external-token") + + with ( + patch("aignostics.platform._client.get_token") as mock_get_token, + patch("aignostics.platform._client.ApiClient"), + patch.object(_AuthenticatedApi, "__init__", lambda self, *a, **kw: None), + ): + Client(token_provider=my_provider) + mock_get_token.assert_not_called() + + +@pytest.mark.unit +def test_external_provider_token_in_auth_header() -> None: + """Test that the external provider's token appears in the Authorization header.""" + my_provider = _make_provider("bearer-value-123") + + with ( + patch("aignostics.platform._client.ApiClient") as api_client_mock, + patch.object(_AuthenticatedApi, "__init__", lambda self, *a, **kw: None), + ): + Client(token_provider=my_provider) + config_used = api_client_mock.call_args[0][0] + auth = config_used.auth_settings() + assert auth["OAuth2AuthorizationCodeBearer"]["value"] == "Bearer bearer-value-123" + + +@pytest.mark.unit +def test_external_provider_singleton_isolation() -> None: + """Test that different providers get different API client instances.""" + provider_a = _make_provider("token-a") + provider_b = _make_provider("token-b") + + with ( + patch("aignostics.platform._client.ApiClient"), + patch.object(_AuthenticatedApi, "__init__", lambda self, *a, **kw: None), + ): + client_a = Client(token_provider=provider_a) + client_b = Client(token_provider=provider_b) + + assert client_a._api is not client_b._api + + +@pytest.mark.unit +def test_external_provider_same_provider_reused() -> None: + """Test that the same provider callable reuses the cached API client.""" + my_provider = _make_provider("reuse-token") + + with ( + patch("aignostics.platform._client.ApiClient"), + patch.object(_AuthenticatedApi, "__init__", lambda self, *a, **kw: None), + ): + client1 = Client(token_provider=my_provider) + client2 = Client(token_provider=my_provider) + + assert client1._api is client2._api + + +@pytest.mark.unit +def test_cache_token_false_with_external_provider_is_allowed() -> None: + """Test that cache_token=False is silently ignored when token_provider is set.""" + with ( + patch("aignostics.platform._client.get_token") as mock_get_token, + patch("aignostics.platform._client.ApiClient"), + patch.object(_AuthenticatedApi, "__init__", lambda self, *a, **kw: None), + ): + # Should not raise; cache_token is irrelevant when using an external provider + Client(token_provider=_make_provider("token"), cache_token=False) + mock_get_token.assert_not_called() + + +@pytest.mark.unit +def test_cache_token_default_with_external_provider_ok() -> None: + """Test that default cache_token=True works with an external token provider.""" + with ( + patch("aignostics.platform._client.ApiClient"), + patch.object(_AuthenticatedApi, "__init__", lambda self, *a, **kw: None), + ): + # Should not raise + Client(token_provider=_make_provider("token")) + + +@pytest.mark.unit +def test_falsy_token_provider_logs_warning() -> None: + """Test that a warning is logged when token_provider returns an empty string.""" + empty_provider = _make_provider("") + config = _OAuth2TokenProviderConfiguration(host="https://dummy", token_provider=empty_provider) + + with patch("aignostics.platform._api.logger") as mock_logger: + result = config.auth_settings() + + assert result == {} + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] + assert "empty or None token" in warning_msg + + +@pytest.mark.unit +def test_none_token_provider_no_warning() -> None: + """Test that no warning is logged when token_provider is not set (None).""" + config = _OAuth2TokenProviderConfiguration(host="https://dummy") + + with patch("aignostics.platform._api.logger") as mock_logger: + result = config.auth_settings() + + assert result == {} + mock_logger.warning.assert_not_called() + + +@pytest.mark.unit +def test_external_provider_cache_bounded() -> None: + """Test that _api_client_external is bounded to _MAX_EXTERNAL_CLIENTS entries.""" + from aignostics.platform._client import _MAX_EXTERNAL_CLIENTS + + with ( + patch("aignostics.platform._client.ApiClient"), + patch.object(_AuthenticatedApi, "__init__", lambda self, *a, **kw: None), + patch("aignostics.platform._client.logger") as mock_logger, + ): + # Create more clients than the limit, each with a distinct provider + for i in range(_MAX_EXTERNAL_CLIENTS + 5): + Client(token_provider=_make_provider(f"token-{i}")) + + # Cache must not exceed the limit (cleared + 1 new entry after overflow) + assert len(Client._api_client_external) <= _MAX_EXTERNAL_CLIENTS + + # A warning should have been logged when the cache was cleared + mock_logger.warning.assert_called() + warning_msg = mock_logger.warning.call_args[0][0] + assert "resource leak" in warning_msg + + +# --- Integration tests --- + + +@pytest.mark.integration +def test_external_provider_wires_through_to_resources() -> None: + """Integration: Client(token_provider=...) wires through real constructors. + + Verifies that an external token provider flows through Client → _AuthenticatedApi → + resource classes (Applications, Runs) without any AttributeError. Only the + ApiClient constructor is mocked to avoid real HTTP calls. + """ + my_provider = _make_provider("integration-test-token") + + with patch("aignostics.platform._client.ApiClient") as mock_api_client_cls: + # Create client with external provider — real _AuthenticatedApi and + # resource constructors (_AuthenticatedResource.__init__) run. + client = Client(token_provider=my_provider) + + # Verify the provider is wired through the real _AuthenticatedApi + assert isinstance(client._api, _AuthenticatedApi) + assert client._api.token_provider is my_provider + + # Verify resources received the same _AuthenticatedApi instance + assert client.applications._api is client._api + assert client.runs._api is client._api + assert client.versions._api is client._api + + # Verify the Configuration passed to ApiClient produces the correct auth header + config = mock_api_client_cls.call_args[0][0] + assert isinstance(config, _OAuth2TokenProviderConfiguration) + auth = config.auth_settings() + assert auth["OAuth2AuthorizationCodeBearer"]["value"] == "Bearer integration-test-token" diff --git a/tests/aignostics/platform/conftest.py b/tests/aignostics/platform/conftest.py index 4a4d7df20..c4e8e5275 100644 --- a/tests/aignostics/platform/conftest.py +++ b/tests/aignostics/platform/conftest.py @@ -5,6 +5,7 @@ import pytest +from aignostics.platform._api import _AuthenticatedApi from aignostics.platform._client import Client from aignostics.platform._operation_cache import _operation_cache from aignostics.platform._service import Service @@ -49,9 +50,9 @@ def mock_api_client() -> MagicMock: """Provide a mock API client. Returns: - MagicMock: A mock of the PublicApi client. + MagicMock: A mock of the _AuthenticatedApi client. """ - return MagicMock() + return MagicMock(spec=_AuthenticatedApi) @pytest.fixture(autouse=True) @@ -64,11 +65,13 @@ def clear_cache() -> t.Generator[None, None, None]: _operation_cache.clear() Client._api_client_cached = None Client._api_client_uncached = None + Client._api_client_external.clear() Service._http_pool = None yield _operation_cache.clear() Client._api_client_cached = None Client._api_client_uncached = None + Client._api_client_external.clear() Service._http_pool = None @@ -88,6 +91,7 @@ def client_with_mock_api(mock_api_client: MagicMock) -> t.Generator[Client, None "exp": 9999999999, "iss": "test-issuer", } + mock_api_client.token_provider = lambda: "test-token-123" with ( patch("aignostics.platform._client.get_token", return_value="test-token-123"), patch("aignostics.platform._authentication.verify_and_decode_token", return_value=mock_token_claims), diff --git a/tests/aignostics/platform/nocache_test.py b/tests/aignostics/platform/nocache_test.py index 364b8260e..7a50a9c3a 100644 --- a/tests/aignostics/platform/nocache_test.py +++ b/tests/aignostics/platform/nocache_test.py @@ -25,7 +25,7 @@ def test_decorator_without_nocache_uses_cache() -> None: """Test that decorated function uses cache by default (nocache=False).""" call_count = 0 - @cached_operation(ttl=60, use_token=False) + @cached_operation(ttl=60) def test_func() -> int: nonlocal call_count call_count += 1 @@ -47,7 +47,7 @@ def test_decorator_with_nocache_false_uses_cache() -> None: """Test that nocache=False explicitly uses cache.""" call_count = 0 - @cached_operation(ttl=60, use_token=False) + @cached_operation(ttl=60) def test_func() -> int: nonlocal call_count call_count += 1 @@ -69,7 +69,7 @@ def test_decorator_with_nocache_true_skips_reading_cache() -> None: """Test that nocache=True skips reading from cache.""" call_count = 0 - @cached_operation(ttl=60, use_token=False) + @cached_operation(ttl=60) def test_func() -> int: nonlocal call_count call_count += 1 @@ -91,7 +91,7 @@ def test_decorator_with_nocache_true_still_writes_to_cache() -> None: """Test that nocache=True still writes the result to cache.""" call_count = 0 - @cached_operation(ttl=60, use_token=False) + @cached_operation(ttl=60) def test_func() -> int: nonlocal call_count call_count += 1 @@ -118,7 +118,7 @@ def test_decorator_nocache_parameter_not_passed_to_function() -> None: """Test that nocache parameter is intercepted and not passed to the decorated function.""" received_kwargs = {} - @cached_operation(ttl=60, use_token=False) + @cached_operation(ttl=60) def test_func(**kwargs: bool) -> dict: nonlocal received_kwargs received_kwargs = kwargs @@ -136,7 +136,7 @@ def test_decorator_with_nocache_and_other_kwargs() -> None: """Test that nocache works alongside other keyword arguments.""" call_count = 0 - @cached_operation(ttl=60, use_token=False) + @cached_operation(ttl=60) def test_func(param1: str = "default", param2: int = 0) -> tuple: nonlocal call_count call_count += 1 @@ -163,7 +163,7 @@ def test_decorator_nocache_with_different_cache_keys() -> None: """Test that nocache respects different cache keys (different args).""" call_count = 0 - @cached_operation(ttl=60, use_token=False) + @cached_operation(ttl=60) def test_func(key: str) -> tuple: nonlocal call_count call_count += 1 @@ -512,7 +512,7 @@ def test_nocache_with_expired_cache_entry() -> None: """Test nocache behavior when cache entry has expired.""" call_count = 0 - @cached_operation(ttl=1, use_token=False) # 1 second TTL + @cached_operation(ttl=1) # 1 second TTL def test_func() -> int: nonlocal call_count call_count += 1 @@ -537,7 +537,7 @@ def test_nocache_clears_expired_entry_before_writing_new() -> None: """Test that nocache properly handles expired entries.""" call_count = 0 - @cached_operation(ttl=1, use_token=False) + @cached_operation(ttl=1) def test_func() -> int: nonlocal call_count call_count += 1 @@ -565,7 +565,7 @@ def test_multiple_consecutive_nocache_calls() -> None: """Test multiple consecutive calls with nocache=True.""" call_count = 0 - @cached_operation(ttl=60, use_token=False) + @cached_operation(ttl=60) def test_func() -> int: nonlocal call_count call_count += 1 @@ -594,7 +594,7 @@ def test_nocache_interleaved_with_normal_calls() -> None: """Test interleaving nocache=True with normal cached calls.""" call_count = 0 - @cached_operation(ttl=60, use_token=False) + @cached_operation(ttl=60) def test_func() -> int: nonlocal call_count call_count += 1 @@ -640,7 +640,7 @@ def test_nocache_after_cache_clear() -> None: """Test that nocache works correctly after cache has been cleared.""" call_count = 0 - @cached_operation(ttl=60, use_token=False) + @cached_operation(ttl=60) def test_func() -> int: nonlocal call_count call_count += 1 @@ -667,7 +667,7 @@ def test_cache_clear_removes_nocache_populated_entries() -> None: """Test that cache clear removes entries populated with nocache=True.""" call_count = 0 - @cached_operation(ttl=60, use_token=False) + @cached_operation(ttl=60) def test_func() -> int: nonlocal call_count call_count += 1 diff --git a/tests/aignostics/platform/resources/applications_test.py b/tests/aignostics/platform/resources/applications_test.py index 2cf1ab60c..4036292cf 100644 --- a/tests/aignostics/platform/resources/applications_test.py +++ b/tests/aignostics/platform/resources/applications_test.py @@ -7,9 +7,9 @@ from unittest.mock import Mock import pytest -from aignx.codegen.api.public_api import PublicApi from aignx.codegen.models.application_read_response import ApplicationReadResponse +from aignostics.platform._api import _AuthenticatedApi from aignostics.platform.resources.applications import Applications, Versions from aignostics.platform.resources.utils import PAGE_SIZE @@ -23,7 +23,10 @@ def mock_api() -> Mock: Returns: Mock: A mock instance of ExternalsApi. """ - return Mock(spec=PublicApi) + api = Mock(spec=_AuthenticatedApi) + api.token_provider = lambda: "test-token" + api.api_client = Mock() + return api @pytest.fixture diff --git a/tests/aignostics/platform/resources/runs_test.py b/tests/aignostics/platform/resources/runs_test.py index cb4749eda..43fcec767 100644 --- a/tests/aignostics/platform/resources/runs_test.py +++ b/tests/aignostics/platform/resources/runs_test.py @@ -7,7 +7,6 @@ from unittest.mock import Mock import pytest -from aignx.codegen.api.public_api import PublicApi from aignx.codegen.models import ( InputArtifactCreationRequest, ItemCreationRequest, @@ -16,6 +15,7 @@ RunReadResponse, ) +from aignostics.platform._api import _AuthenticatedApi from aignostics.platform.resources.runs import LIST_APPLICATION_RUNS_MAX_PAGE_SIZE, Run, Runs from aignostics.platform.resources.utils import PAGE_SIZE @@ -27,7 +27,10 @@ def mock_api() -> Mock: Returns: Mock: A mock instance of ExternalsApi. """ - return Mock(spec=PublicApi) + api = Mock(spec=_AuthenticatedApi) + api.token_provider = lambda: "test-token" + api.api_client = Mock() + return api @pytest.fixture