Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/a2a/client/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def get_credentials(
session_id = context.state['sessionId']
return self._store.get(session_id, {}).get(security_scheme_name)

async def set_credentials(
def set_credentials(
self, session_id: str, security_scheme_name: str, credential: str
) -> None:
"""Method to populate the store."""
Expand Down
134 changes: 78 additions & 56 deletions src/a2a/client/auth/interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
BeforeArgs,
ClientCallInterceptor,
)
from a2a.types.a2a_pb2 import SecurityScheme

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -35,62 +36,83 @@ async def before(self, args: BeforeArgs) -> None:

for requirement in agent_card.security_requirements:
for scheme_name in requirement.schemes:
credential = await self._credential_service.get_credentials(
scheme_name, args.context
)
if credential and scheme_name in agent_card.security_schemes:
scheme = agent_card.security_schemes[scheme_name]

if args.context is None:
args.context = ClientCallContext()

if args.context.service_parameters is None:
args.context.service_parameters = {}

# HTTP Bearer authentication
if (
scheme.HasField('http_auth_security_scheme')
and scheme.http_auth_security_scheme.scheme.lower()
== 'bearer'
):
args.context.service_parameters['Authorization'] = (
f'Bearer {credential}'
)
logger.debug(
"Added Bearer token for scheme '%s'.",
scheme_name,
)
return

# OAuth2 and OIDC schemes are implicitly Bearer
if scheme.HasField(
'oauth2_security_scheme'
) or scheme.HasField('open_id_connect_security_scheme'):
args.context.service_parameters['Authorization'] = (
f'Bearer {credential}'
)
logger.debug(
"Added Bearer token for scheme '%s'.",
scheme_name,
)
return

# API Key in Header
if (
scheme.HasField('api_key_security_scheme')
and scheme.api_key_security_scheme.location.lower()
== 'header'
):
args.context.service_parameters[
scheme.api_key_security_scheme.name
] = credential
logger.debug(
"Added API Key Header for scheme '%s'.",
scheme_name,
)
return

# Note: Other cases like API keys in query/cookie are not handled and will be skipped.
if await self._apply_credential(args, scheme_name):
return

async def _apply_credential(
self, args: BeforeArgs, scheme_name: str
) -> bool:
"""Fetches and applies a credential for a single scheme. Returns True if request should stop."""
agent_card = args.agent_card
credential = await self._credential_service.get_credentials(
scheme_name, args.context
)
if not credential or scheme_name not in agent_card.security_schemes:
return False

scheme = agent_card.security_schemes[scheme_name]
self._ensure_context(args)
context = args.context
if context is None:
return False
params = context.service_parameters
if params is None:
return False
if self._apply_bearer(params, scheme, scheme_name, credential):
return True
return self._apply_api_key(params, scheme, scheme_name, credential)

def _ensure_context(self, args: BeforeArgs) -> None:
"""Ensures the client call context and service parameters exist."""
if args.context is None:
args.context = ClientCallContext()
if args.context.service_parameters is None:
args.context.service_parameters = {}

def _apply_bearer(
self,
service_parameters: dict[str, str],
scheme: SecurityScheme,
scheme_name: str,
credential: str,
) -> bool:
"""Applies Bearer token for HTTP Bearer, OAuth2, or OIDC schemes. Returns True if applied."""
is_http_bearer = (
scheme.HasField('http_auth_security_scheme')
and scheme.http_auth_security_scheme.scheme.lower() == 'bearer'
)
is_oauth2_or_oidc = scheme.HasField(
'oauth2_security_scheme'
) or scheme.HasField('open_id_connect_security_scheme')

if is_http_bearer or is_oauth2_or_oidc:
service_parameters['Authorization'] = f'Bearer {credential}'
logger.debug(
"Added Bearer token for scheme '%s'.",
scheme_name,
)
return True
return False

def _apply_api_key(
self,
service_parameters: dict[str, str],
scheme: SecurityScheme,
scheme_name: str,
credential: str,
) -> bool:
"""Applies API Key header. Returns True if applied."""
if (
scheme.HasField('api_key_security_scheme')
and scheme.api_key_security_scheme.location.lower() == 'header'
):
service_parameters[scheme.api_key_security_scheme.name] = credential
logger.debug(
"Added API Key Header for scheme '%s'.",
scheme_name,
)
return True
return False

async def after(self, args: AfterArgs) -> None:
"""Invoked after the method is executed."""
4 changes: 2 additions & 2 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,33 @@
from google.protobuf import json_format
from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response

from a2a.client.client import ClientCallContext
from a2a.client.errors import A2AClientError
from a2a.client.transports.base import ClientTransport
from a2a.client.transports.http_helpers import (
get_http_args,
send_http_request,
send_http_stream_request,
)
from a2a.types.a2a_pb2 import (
AgentCard,
CancelTaskRequest,
DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
ListTaskPushNotificationConfigsRequest,
ListTaskPushNotificationConfigsResponse,
ListTasksRequest,
ListTasksResponse,
SendMessageRequest,
SendMessageResponse,
StreamResponse,
SubscribeToTaskRequest,
Task,
TaskPushNotificationConfig,
)
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, A2AError

Check notice on line 38 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (11-37)
from a2a.utils.telemetry import SpanKind, trace_class


Expand Down Expand Up @@ -315,7 +315,7 @@
"""Closes the httpx client."""
await self.httpx_client.aclose()

def _create_jsonrpc_error(self, error_dict: dict[str, Any]) -> Exception:
def _create_jsonrpc_error(self, error_dict: dict[str, Any]) -> A2AError:
"""Creates the appropriate A2AError from a JSON-RPC error dictionary."""
code = error_dict.get('code')
message = error_dict.get('message', str(error_dict))
Expand Down
46 changes: 31 additions & 15 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,33 @@

from google.protobuf.json_format import MessageToDict, Parse, ParseDict

from a2a.client.client import ClientCallContext
from a2a.client.errors import A2AClientError
from a2a.client.transports.base import ClientTransport
from a2a.client.transports.http_helpers import (
get_http_args,
send_http_request,
send_http_stream_request,
)
from a2a.types.a2a_pb2 import (
AgentCard,
CancelTaskRequest,
DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
ListTaskPushNotificationConfigsRequest,
ListTaskPushNotificationConfigsResponse,
ListTasksRequest,
ListTasksResponse,
SendMessageRequest,
SendMessageResponse,
StreamResponse,
SubscribeToTaskRequest,
Task,
TaskPushNotificationConfig,
)
from a2a.utils.errors import A2A_REASON_TO_ERROR, MethodNotFoundError
from a2a.utils.errors import A2A_REASON_TO_ERROR, A2AError, MethodNotFoundError

Check notice on line 37 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (12-38)
from a2a.utils.telemetry import SpanKind, trace_class


Expand Down Expand Up @@ -64,24 +64,40 @@
# We extract the first `ErrorInfo` object because it contains the
# specific `reason` code needed to map this back to a Python A2AError.
for d in details:
if (
isinstance(d, dict)
and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo'
):
reason = d.get('reason')
metadata = d.get('metadata') or {}
if isinstance(reason, str):
exception_cls = A2A_REASON_TO_ERROR.get(reason)
if exception_cls:
exc = exception_cls(message)
if metadata:
exc.data = metadata
return exc
exc = _extract_error_info(d, message)
if exc is not None:
return exc
if _is_error_info(d):
break

return None


def _is_error_info(d: Any) -> bool:
"""Checks if a detail entry is an ErrorInfo object."""
return (
isinstance(d, dict)
and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo'
)


def _extract_error_info(d: Any, message: str) -> A2AError | None:
"""Extracts an A2AError from an ErrorInfo detail entry."""
if not _is_error_info(d):
return None
reason = d.get('reason')
if not isinstance(reason, str):
return None
exception_cls = A2A_REASON_TO_ERROR.get(reason)
if not exception_cls:
return None
exc = exception_cls(message)
metadata = d.get('metadata') or {}
if metadata:
exc.data = metadata
return exc


@trace_class(kind=SpanKind.CLIENT)
class RestTransport(ClientTransport):
"""A REST transport for the A2A client."""
Expand Down Expand Up @@ -338,7 +354,7 @@
mapped = _parse_rest_error(error_payload, str(e))
if mapped:
raise mapped from e
except (json.JSONDecodeError, ValueError):
except ValueError:
pass

status_code = e.response.status_code
Expand Down
8 changes: 4 additions & 4 deletions tests/client/test_auth_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def test_in_memory_context_credential_store(
session_id = 'session-id'
scheme_name = 'test-scheme'
credential = 'test-token'
await store.set_credentials(session_id, scheme_name, credential)
store.set_credentials(session_id, scheme_name, credential)

# Assert: Successful retrieval
context = ClientCallContext(state={'sessionId': session_id})
Expand All @@ -144,7 +144,7 @@ async def test_in_memory_context_credential_store(
assert retrieved_credential_empty is None
# Assert: Overwrite the credential when session_id already exists
new_credential = 'new-token'
await store.set_credentials(session_id, scheme_name, new_credential)
store.set_credentials(session_id, scheme_name, new_credential)
assert await store.get_credentials(scheme_name, context) == new_credential


Expand Down Expand Up @@ -249,7 +249,7 @@ async def test_auth_interceptor_variants(
test_case: AuthTestCase, store: InMemoryContextCredentialStore
) -> None:
"""Parametrized test verifying that AuthInterceptor correctly attaches credentials based on the defined security scheme in the AgentCard."""
await store.set_credentials(
store.set_credentials(
test_case.session_id, test_case.scheme_name, test_case.credential
)
auth_interceptor = AuthInterceptor(credential_service=store)
Expand Down Expand Up @@ -300,7 +300,7 @@ async def test_auth_interceptor_skips_when_scheme_not_in_security_schemes(
scheme_name = 'missing'
session_id = 'session-id'
credential = 'test-token'
await store.set_credentials(session_id, scheme_name, credential)
store.set_credentials(session_id, scheme_name, credential)
auth_interceptor = AuthInterceptor(credential_service=store)
agent_card = AgentCard(
supported_interfaces=[
Expand Down
Loading