-
Notifications
You must be signed in to change notification settings - Fork 448
feat: add composable URL validation utilities #1114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,179 @@ | ||||||||||||||||||||||
| # Copyright 2026 Google LLC | ||||||||||||||||||||||
| # | ||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||
| # | ||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||
| # | ||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| """Composable URL validation utilities.""" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||
| import ipaddress | ||||||||||||||||||||||
| import socket | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from abc import ABC, abstractmethod | ||||||||||||||||||||||
| from collections.abc import Callable, Sequence | ||||||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||||||
| from urllib.parse import SplitResult, urlsplit | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| IPAddress = ipaddress.IPv4Address | ipaddress.IPv6Address | ||||||||||||||||||||||
| Resolver = Callable[[str, int | None], Sequence[IPAddress | str]] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class InvalidUrlError(ValueError): | ||||||||||||||||||||||
| """Raised when URL validation rejects a URL.""" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @dataclass(frozen=True) | ||||||||||||||||||||||
| class ResolvedUrl: | ||||||||||||||||||||||
| """A parsed URL and the resolved addresses used for validation.""" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| raw: str | ||||||||||||||||||||||
| parsed: SplitResult | ||||||||||||||||||||||
| addresses: tuple[IPAddress, ...] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class UrlValidationRule(ABC): | ||||||||||||||||||||||
| """A composable URL validation rule.""" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @abstractmethod | ||||||||||||||||||||||
| async def check(self, url: ResolvedUrl) -> None: | ||||||||||||||||||||||
| """Raise InvalidUrlError to reject the URL.""" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class RequireScheme(UrlValidationRule): | ||||||||||||||||||||||
| """Require a URL scheme to be one of the configured schemes.""" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __init__(self, allowed_schemes: Sequence[str]) -> None: | ||||||||||||||||||||||
| if not allowed_schemes: | ||||||||||||||||||||||
| raise ValueError('allowed_schemes must not be empty.') | ||||||||||||||||||||||
| self._allowed_schemes = frozenset( | ||||||||||||||||||||||
| scheme.lower() for scheme in allowed_schemes | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| async def check(self, url: ResolvedUrl) -> None: | ||||||||||||||||||||||
| """Reject URLs whose scheme is not configured as allowed.""" | ||||||||||||||||||||||
| scheme = url.parsed.scheme.lower() | ||||||||||||||||||||||
| if scheme not in self._allowed_schemes: | ||||||||||||||||||||||
| allowed = ', '.join(sorted(self._allowed_schemes)) | ||||||||||||||||||||||
| raise InvalidUrlError( | ||||||||||||||||||||||
| f'URL scheme {url.parsed.scheme!r} is not allowed. ' | ||||||||||||||||||||||
| f'Allowed schemes: {allowed}.' | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class BlockPrivateNetworks(UrlValidationRule): | ||||||||||||||||||||||
| """Reject URLs resolving to non-public IP addresses. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Hosts in ``allow_hosts`` and addresses covered by ``allow_cidrs`` are | ||||||||||||||||||||||
| exempt from the non-public address check. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||
| self, | ||||||||||||||||||||||
| *, | ||||||||||||||||||||||
| allow_hosts: Sequence[str] = (), | ||||||||||||||||||||||
| allow_cidrs: Sequence[str] = (), | ||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||
| self._allow_hosts = frozenset( | ||||||||||||||||||||||
| _normalize_host(host) for host in allow_hosts | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| self._allow_networks = tuple( | ||||||||||||||||||||||
| ipaddress.ip_network(cidr, strict=False) for cidr in allow_cidrs | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| async def check(self, url: ResolvedUrl) -> None: | ||||||||||||||||||||||
| """Reject URLs that resolve to non-public addresses.""" | ||||||||||||||||||||||
| host = url.parsed.hostname | ||||||||||||||||||||||
| if host is not None and _normalize_host(host) in self._allow_hosts: | ||||||||||||||||||||||
| return | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for address in url.addresses: | ||||||||||||||||||||||
| if any(address in network for network in self._allow_networks): | ||||||||||||||||||||||
| continue | ||||||||||||||||||||||
| if not address.is_global: | ||||||||||||||||||||||
| raise InvalidUrlError( | ||||||||||||||||||||||
| f'URL host {host!r} resolves to non-public address ' | ||||||||||||||||||||||
| f'{address}.' | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class UrlValidator: | ||||||||||||||||||||||
| """Validate URLs by parsing, resolving, then running rules in order.""" | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||
| self, | ||||||||||||||||||||||
| rules: Sequence[UrlValidationRule] = (), | ||||||||||||||||||||||
| *, | ||||||||||||||||||||||
| resolve: bool = True, | ||||||||||||||||||||||
| resolver: Resolver | None = None, | ||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||
| self._rules = tuple(rules) | ||||||||||||||||||||||
| self._resolve = resolve | ||||||||||||||||||||||
| self._resolver = resolver | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| async def validate(self, url: str) -> ResolvedUrl: | ||||||||||||||||||||||
| """Validate a URL and return the parsed URL plus resolved addresses.""" | ||||||||||||||||||||||
| resolved = await self._build(url) | ||||||||||||||||||||||
| for rule in self._rules: | ||||||||||||||||||||||
| await rule.check(resolved) | ||||||||||||||||||||||
| return resolved | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| async def _build(self, url: str) -> ResolvedUrl: | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| parsed = urlsplit(url) | ||||||||||||||||||||||
| host = parsed.hostname | ||||||||||||||||||||||
| port = parsed.port | ||||||||||||||||||||||
| except ValueError as exc: | ||||||||||||||||||||||
| raise InvalidUrlError(f'Invalid URL {url!r}: {exc}') from exc | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| addresses: tuple[IPAddress, ...] = () | ||||||||||||||||||||||
| if self._resolve: | ||||||||||||||||||||||
| if host is None: | ||||||||||||||||||||||
| raise InvalidUrlError(f'URL {url!r} does not include a host.') | ||||||||||||||||||||||
| addresses = await self._resolve_host(host, port) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return ResolvedUrl(raw=url, parsed=parsed, addresses=addresses) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| async def _resolve_host( | ||||||||||||||||||||||
| self, host: str, port: int | None | ||||||||||||||||||||||
| ) -> tuple[IPAddress, ...]: | ||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| return (ipaddress.ip_address(host),) | ||||||||||||||||||||||
| except ValueError: | ||||||||||||||||||||||
| pass | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| try: | ||||||||||||||||||||||
| if self._resolver is not None: | ||||||||||||||||||||||
| resolved = self._resolver(host, port) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| loop = asyncio.get_running_loop() | ||||||||||||||||||||||
| address_info = await loop.getaddrinfo( | ||||||||||||||||||||||
| host, | ||||||||||||||||||||||
| port, | ||||||||||||||||||||||
| type=socket.SOCK_STREAM, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| resolved = [info[4][0] for info in address_info] | ||||||||||||||||||||||
| except OSError as exc: | ||||||||||||||||||||||
| raise InvalidUrlError( | ||||||||||||||||||||||
| f'Could not resolve URL host {host!r}: {exc}' | ||||||||||||||||||||||
| ) from exc | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| addresses = tuple( | ||||||||||||||||||||||
| dict.fromkeys(ipaddress.ip_address(address) for address in resolved) | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
Comment on lines
+170
to
+172
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||
| if not addresses: | ||||||||||||||||||||||
| raise InvalidUrlError(f'URL host {host!r} did not resolve.') | ||||||||||||||||||||||
| return addresses | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _normalize_host(host: str) -> str: | ||||||||||||||||||||||
| return host.rstrip('.').lower() | ||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,176 @@ | ||
| """Tests for a2a.utils.url_validator.""" | ||
|
|
||
| import ipaddress | ||
|
|
||
| import pytest | ||
|
|
||
| from a2a.utils.url_validator import ( | ||
| BlockPrivateNetworks, | ||
| InvalidUrlError, | ||
| RequireScheme, | ||
| ResolvedUrl, | ||
| UrlValidationRule, | ||
| UrlValidator, | ||
| ) | ||
|
|
||
|
|
||
| class RecordingRule(UrlValidationRule): | ||
| """Records the resolved URL passed to the rule.""" | ||
|
|
||
| def __init__(self) -> None: | ||
| self.seen: list[ResolvedUrl] = [] | ||
|
|
||
| async def check(self, url: ResolvedUrl) -> None: | ||
| self.seen.append(url) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_validate_resolves_and_returns_pinned_addresses() -> None: | ||
| """UrlValidator returns parsed URL details and resolved addresses.""" | ||
|
|
||
| def resolver(host: str, port: int | None) -> list[str]: | ||
| assert host == 'example.com' | ||
| assert port == 443 | ||
| return ['93.184.216.34', '93.184.216.34'] | ||
|
|
||
| rule = RecordingRule() | ||
| validator = UrlValidator( | ||
| [RequireScheme(['https']), rule], | ||
| resolver=resolver, | ||
| ) | ||
|
|
||
| result = await validator.validate('https://example.com:443/agent') | ||
|
|
||
| assert result.raw == 'https://example.com:443/agent' | ||
| assert result.parsed.scheme == 'https' | ||
| assert result.parsed.hostname == 'example.com' | ||
| assert result.addresses == (ipaddress.ip_address('93.184.216.34'),) | ||
| assert rule.seen == [result] | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_require_scheme_rejects_disallowed_scheme() -> None: | ||
| """RequireScheme rejects URLs whose scheme is not allowed.""" | ||
| validator = UrlValidator( | ||
| [RequireScheme(['https'])], | ||
| resolve=False, | ||
| ) | ||
|
|
||
| with pytest.raises(InvalidUrlError, match='not allowed'): | ||
| await validator.validate('http://example.com') | ||
|
|
||
|
|
||
| def test_require_scheme_rejects_empty_allowed_schemes() -> None: | ||
| """RequireScheme needs at least one allowed scheme.""" | ||
| with pytest.raises(ValueError, match='must not be empty'): | ||
| RequireScheme([]) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_validate_rejects_url_without_host_when_resolving() -> None: | ||
| """URL resolution requires a host.""" | ||
| validator = UrlValidator([RequireScheme(['https'])]) | ||
|
|
||
| with pytest.raises(InvalidUrlError, match='does not include a host'): | ||
| await validator.validate('https:///missing-host') | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_validate_rejects_invalid_url() -> None: | ||
| """Malformed URLs are reported as InvalidUrlError.""" | ||
| validator = UrlValidator(resolve=False) | ||
|
|
||
| with pytest.raises(InvalidUrlError, match='Invalid URL'): | ||
| await validator.validate('http://example.com:not-a-port') | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_block_private_networks_rejects_loopback_address() -> None: | ||
| """BlockPrivateNetworks rejects non-public resolved addresses.""" | ||
| validator = UrlValidator([BlockPrivateNetworks()]) | ||
|
|
||
| with pytest.raises(InvalidUrlError, match='non-public address 127.0.0.1'): | ||
| await validator.validate('http://127.0.0.1/callback') | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_block_private_networks_allows_configured_host() -> None: | ||
| """BlockPrivateNetworks allows explicitly configured hosts.""" | ||
|
|
||
| def resolver(host: str, port: int | None) -> list[str]: | ||
| return ['127.0.0.1'] | ||
|
|
||
| validator = UrlValidator( | ||
| [BlockPrivateNetworks(allow_hosts=['internal.example.test'])], | ||
| resolver=resolver, | ||
| ) | ||
|
|
||
| result = await validator.validate('http://internal.example.test/callback') | ||
|
|
||
| assert result.addresses == (ipaddress.ip_address('127.0.0.1'),) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_block_private_networks_allows_configured_cidr() -> None: | ||
| """BlockPrivateNetworks allows addresses in configured CIDRs.""" | ||
| validator = UrlValidator([BlockPrivateNetworks(allow_cidrs=['10.0.0.0/8'])]) | ||
|
|
||
| result = await validator.validate('http://10.1.2.3/callback') | ||
|
|
||
| assert result.addresses == (ipaddress.ip_address('10.1.2.3'),) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_block_private_networks_rejects_mixed_disallowed_address() -> ( | ||
| None | ||
| ): | ||
| """All resolved addresses must be public or explicitly allowed.""" | ||
|
|
||
| def resolver(host: str, port: int | None) -> list[str]: | ||
| return ['93.184.216.34', '10.1.2.3'] | ||
|
|
||
| validator = UrlValidator([BlockPrivateNetworks()], resolver=resolver) | ||
|
|
||
| with pytest.raises(InvalidUrlError, match='10.1.2.3'): | ||
| await validator.validate('http://example.com/callback') | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_validate_reports_resolution_failures() -> None: | ||
| """Resolver failures are reported as InvalidUrlError.""" | ||
|
|
||
| def resolver(host: str, port: int | None) -> list[str]: | ||
| raise OSError('name lookup failed') | ||
|
|
||
| validator = UrlValidator(resolver=resolver) | ||
|
|
||
| with pytest.raises(InvalidUrlError, match='Could not resolve URL host'): | ||
| await validator.validate('http://example.com/callback') | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_validate_rejects_empty_resolution_result() -> None: | ||
| """Resolvers must return at least one address.""" | ||
|
|
||
| def resolver(host: str, port: int | None) -> list[str]: | ||
| return [] | ||
|
|
||
| validator = UrlValidator(resolver=resolver) | ||
|
|
||
| with pytest.raises(InvalidUrlError, match='did not resolve'): | ||
| await validator.validate('http://example.com/callback') | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_validate_without_resolution_runs_rules_with_empty_addresses() -> ( | ||
| None | ||
| ): | ||
| """UrlValidator can skip DNS resolution for parse-only validation.""" | ||
| rule = RecordingRule() | ||
| validator = UrlValidator([rule], resolve=False) | ||
|
|
||
| result = await validator.validate('custom://agent/path') | ||
|
|
||
| assert result.parsed.scheme == 'custom' | ||
| assert result.addresses == () | ||
| assert rule.seen == [result] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If
resolve=Falseis configured on theUrlValidator,url.addresseswill be empty. In this case,BlockPrivateNetworkswill silently allow any URL, even if the host is a private IP address literal (e.g.,http://127.0.0.1/). To ensure robust defense-in-depth and prevent SSRF bypasses,BlockPrivateNetworksshould attempt to parse the host as an IP address literal ifurl.addressesis empty.