diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index 04693dd0b..7e488f033 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -7,12 +7,26 @@ TransportProtocol, ) from a2a.utils.proto_utils import to_stream_response +from a2a.utils.url_validator import ( + BlockPrivateNetworks, + InvalidUrlError, + RequireScheme, + ResolvedUrl, + UrlValidationRule, + UrlValidator, +) __all__ = [ 'AGENT_CARD_WELL_KNOWN_PATH', 'DEFAULT_RPC_URL', + 'BlockPrivateNetworks', + 'InvalidUrlError', + 'RequireScheme', + 'ResolvedUrl', 'TransportProtocol', + 'UrlValidationRule', + 'UrlValidator', 'proto_utils', 'to_stream_response', ] diff --git a/src/a2a/utils/url_validator.py b/src/a2a/utils/url_validator.py new file mode 100644 index 000000000..622658861 --- /dev/null +++ b/src/a2a/utils/url_validator.py @@ -0,0 +1,179 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Composable URL validation utilities.""" + +import asyncio +import ipaddress +import socket + +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from urllib.parse import SplitResult, urlsplit + + +IPAddress = ipaddress.IPv4Address | ipaddress.IPv6Address +Resolver = Callable[[str, int | None], Sequence[IPAddress | str]] + + +class InvalidUrlError(ValueError): + """Raised when URL validation rejects a URL.""" + + +@dataclass(frozen=True) +class ResolvedUrl: + """A parsed URL and the resolved addresses used for validation.""" + + raw: str + parsed: SplitResult + addresses: tuple[IPAddress, ...] + + +class UrlValidationRule(ABC): + """A composable URL validation rule.""" + + @abstractmethod + async def check(self, url: ResolvedUrl) -> None: + """Raise InvalidUrlError to reject the URL.""" + + +class RequireScheme(UrlValidationRule): + """Require a URL scheme to be one of the configured schemes.""" + + def __init__(self, allowed_schemes: Sequence[str]) -> None: + if not allowed_schemes: + raise ValueError('allowed_schemes must not be empty.') + self._allowed_schemes = frozenset( + scheme.lower() for scheme in allowed_schemes + ) + + async def check(self, url: ResolvedUrl) -> None: + """Reject URLs whose scheme is not configured as allowed.""" + scheme = url.parsed.scheme.lower() + if scheme not in self._allowed_schemes: + allowed = ', '.join(sorted(self._allowed_schemes)) + raise InvalidUrlError( + f'URL scheme {url.parsed.scheme!r} is not allowed. ' + f'Allowed schemes: {allowed}.' + ) + + +class BlockPrivateNetworks(UrlValidationRule): + """Reject URLs resolving to non-public IP addresses. + + Hosts in ``allow_hosts`` and addresses covered by ``allow_cidrs`` are + exempt from the non-public address check. 
+ """ + + def __init__( + self, + *, + allow_hosts: Sequence[str] = (), + allow_cidrs: Sequence[str] = (), + ) -> None: + self._allow_hosts = frozenset( + _normalize_host(host) for host in allow_hosts + ) + self._allow_networks = tuple( + ipaddress.ip_network(cidr, strict=False) for cidr in allow_cidrs + ) + + async def check(self, url: ResolvedUrl) -> None: + """Reject URLs that resolve to non-public addresses.""" + host = url.parsed.hostname + if host is not None and _normalize_host(host) in self._allow_hosts: + return + + for address in url.addresses: + if any(address in network for network in self._allow_networks): + continue + if not address.is_global: + raise InvalidUrlError( + f'URL host {host!r} resolves to non-public address ' + f'{address}.' + ) + + +class UrlValidator: + """Validate URLs by parsing, resolving, then running rules in order.""" + + def __init__( + self, + rules: Sequence[UrlValidationRule] = (), + *, + resolve: bool = True, + resolver: Resolver | None = None, + ) -> None: + self._rules = tuple(rules) + self._resolve = resolve + self._resolver = resolver + + async def validate(self, url: str) -> ResolvedUrl: + """Validate a URL and return the parsed URL plus resolved addresses.""" + resolved = await self._build(url) + for rule in self._rules: + await rule.check(resolved) + return resolved + + async def _build(self, url: str) -> ResolvedUrl: + try: + parsed = urlsplit(url) + host = parsed.hostname + port = parsed.port + except ValueError as exc: + raise InvalidUrlError(f'Invalid URL {url!r}: {exc}') from exc + + addresses: tuple[IPAddress, ...] 
= () + if self._resolve: + if host is None: + raise InvalidUrlError(f'URL {url!r} does not include a host.') + addresses = await self._resolve_host(host, port) + + return ResolvedUrl(raw=url, parsed=parsed, addresses=addresses) + + async def _resolve_host( + self, host: str, port: int | None + ) -> tuple[IPAddress, ...]: + try: + return (ipaddress.ip_address(host),) + except ValueError: + pass + + try: + if self._resolver is not None: + resolved = self._resolver(host, port) + else: + loop = asyncio.get_running_loop() + address_info = await loop.getaddrinfo( + host, + port, + type=socket.SOCK_STREAM, + ) + resolved = [info[4][0] for info in address_info] + except OSError as exc: + raise InvalidUrlError( + f'Could not resolve URL host {host!r}: {exc}' + ) from exc + + addresses = tuple( + dict.fromkeys(ipaddress.ip_address(address) for address in resolved) + ) + if not addresses: + raise InvalidUrlError(f'URL host {host!r} did not resolve.') + return addresses + + +def _normalize_host(host: str) -> str: + return host.rstrip('.').lower() diff --git a/tests/utils/test_url_validator.py b/tests/utils/test_url_validator.py new file mode 100644 index 000000000..c8c962211 --- /dev/null +++ b/tests/utils/test_url_validator.py @@ -0,0 +1,176 @@ +"""Tests for a2a.utils.url_validator.""" + +import ipaddress + +import pytest + +from a2a.utils.url_validator import ( + BlockPrivateNetworks, + InvalidUrlError, + RequireScheme, + ResolvedUrl, + UrlValidationRule, + UrlValidator, +) + + +class RecordingRule(UrlValidationRule): + """Records the resolved URL passed to the rule.""" + + def __init__(self) -> None: + self.seen: list[ResolvedUrl] = [] + + async def check(self, url: ResolvedUrl) -> None: + self.seen.append(url) + + +@pytest.mark.asyncio +async def test_validate_resolves_and_returns_pinned_addresses() -> None: + """UrlValidator returns parsed URL details and resolved addresses.""" + + def resolver(host: str, port: int | None) -> list[str]: + assert host == 'example.com' + 
"""Tests for a2a.utils.url_validator."""

import ipaddress
import re

import pytest

from a2a.utils.url_validator import (
    BlockPrivateNetworks,
    InvalidUrlError,
    RequireScheme,
    ResolvedUrl,
    UrlValidationRule,
    UrlValidator,
)


class RecordingRule(UrlValidationRule):
    """Records the resolved URL passed to the rule."""

    def __init__(self) -> None:
        self.seen: list[ResolvedUrl] = []

    async def check(self, url: ResolvedUrl) -> None:
        self.seen.append(url)


@pytest.mark.asyncio
async def test_validate_resolves_and_returns_pinned_addresses() -> None:
    """UrlValidator returns parsed URL details and resolved addresses."""

    def resolver(host: str, port: int | None) -> list[str]:
        assert host == 'example.com'
        assert port == 443
        # The duplicate checks that resolved addresses are de-duplicated.
        return ['93.184.216.34', '93.184.216.34']

    rule = RecordingRule()
    validator = UrlValidator(
        [RequireScheme(['https']), rule],
        resolver=resolver,
    )

    result = await validator.validate('https://example.com:443/agent')

    assert result.raw == 'https://example.com:443/agent'
    assert result.parsed.scheme == 'https'
    assert result.parsed.hostname == 'example.com'
    assert result.addresses == (ipaddress.ip_address('93.184.216.34'),)
    assert rule.seen == [result]


@pytest.mark.asyncio
async def test_require_scheme_rejects_disallowed_scheme() -> None:
    """RequireScheme rejects URLs whose scheme is not allowed."""
    validator = UrlValidator(
        [RequireScheme(['https'])],
        resolve=False,
    )

    with pytest.raises(InvalidUrlError, match='not allowed'):
        await validator.validate('http://example.com')


def test_require_scheme_rejects_empty_allowed_schemes() -> None:
    """RequireScheme needs at least one allowed scheme."""
    with pytest.raises(ValueError, match='must not be empty'):
        RequireScheme([])


@pytest.mark.asyncio
async def test_validate_rejects_url_without_host_when_resolving() -> None:
    """URL resolution requires a host."""
    validator = UrlValidator([RequireScheme(['https'])])

    with pytest.raises(InvalidUrlError, match='does not include a host'):
        await validator.validate('https:///missing-host')


@pytest.mark.asyncio
async def test_validate_rejects_invalid_url() -> None:
    """Malformed URLs are reported as InvalidUrlError."""
    validator = UrlValidator(resolve=False)

    with pytest.raises(InvalidUrlError, match='Invalid URL'):
        await validator.validate('http://example.com:not-a-port')


@pytest.mark.asyncio
async def test_block_private_networks_rejects_loopback_address() -> None:
    """BlockPrivateNetworks rejects non-public resolved addresses."""
    validator = UrlValidator([BlockPrivateNetworks()])

    # ``match`` is a regular expression; escape the dots so the address is
    # matched literally.
    with pytest.raises(
        InvalidUrlError,
        match=re.escape('non-public address 127.0.0.1'),
    ):
        await validator.validate('http://127.0.0.1/callback')


@pytest.mark.asyncio
async def test_block_private_networks_allows_configured_host() -> None:
    """BlockPrivateNetworks allows explicitly configured hosts."""

    def resolver(host: str, port: int | None) -> list[str]:
        return ['127.0.0.1']

    validator = UrlValidator(
        [BlockPrivateNetworks(allow_hosts=['internal.example.test'])],
        resolver=resolver,
    )

    result = await validator.validate('http://internal.example.test/callback')

    assert result.addresses == (ipaddress.ip_address('127.0.0.1'),)


@pytest.mark.asyncio
async def test_block_private_networks_allows_configured_cidr() -> None:
    """BlockPrivateNetworks allows addresses in configured CIDRs."""
    validator = UrlValidator([BlockPrivateNetworks(allow_cidrs=['10.0.0.0/8'])])

    result = await validator.validate('http://10.1.2.3/callback')

    assert result.addresses == (ipaddress.ip_address('10.1.2.3'),)


@pytest.mark.asyncio
async def test_block_private_networks_rejects_mixed_disallowed_address() -> (
    None
):
    """All resolved addresses must be public or explicitly allowed."""

    def resolver(host: str, port: int | None) -> list[str]:
        return ['93.184.216.34', '10.1.2.3']

    validator = UrlValidator([BlockPrivateNetworks()], resolver=resolver)

    # Escaped so the pattern matches the address literally.
    with pytest.raises(InvalidUrlError, match=re.escape('10.1.2.3')):
        await validator.validate('http://example.com/callback')


@pytest.mark.asyncio
async def test_validate_reports_resolution_failures() -> None:
    """Resolver failures are reported as InvalidUrlError."""

    def resolver(host: str, port: int | None) -> list[str]:
        raise OSError('name lookup failed')

    validator = UrlValidator(resolver=resolver)

    with pytest.raises(InvalidUrlError, match='Could not resolve URL host'):
        await validator.validate('http://example.com/callback')


@pytest.mark.asyncio
async def test_validate_rejects_empty_resolution_result() -> None:
    """Resolvers must return at least one address."""

    def resolver(host: str, port: int | None) -> list[str]:
        return []

    validator = UrlValidator(resolver=resolver)

    with pytest.raises(InvalidUrlError, match='did not resolve'):
        await validator.validate('http://example.com/callback')


@pytest.mark.asyncio
async def test_validate_without_resolution_runs_rules_with_empty_addresses() -> (
    None
):
    """UrlValidator can skip DNS resolution for parse-only validation."""
    rule = RecordingRule()
    validator = UrlValidator([rule], resolve=False)

    result = await validator.validate('custom://agent/path')

    assert result.parsed.scheme == 'custom'
    assert result.addresses == ()
    assert rule.seen == [result]