Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/a2a/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,26 @@
TransportProtocol,
)
from a2a.utils.proto_utils import to_stream_response
from a2a.utils.url_validator import (
BlockPrivateNetworks,
InvalidUrlError,
RequireScheme,
ResolvedUrl,
UrlValidationRule,
UrlValidator,
)


# Names exported by ``from a2a.utils import *``. Entries are grouped as
# upper-case names, then CamelCase names, then lower-case names, and each
# group is kept in alphabetical order.
__all__ = [
    'AGENT_CARD_WELL_KNOWN_PATH',
    'DEFAULT_RPC_URL',
    'BlockPrivateNetworks',
    'InvalidUrlError',
    'RequireScheme',
    'ResolvedUrl',
    'TransportProtocol',
    'UrlValidationRule',
    'UrlValidator',
    'proto_utils',
    'to_stream_response',
]
179 changes: 179 additions & 0 deletions src/a2a/utils/url_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Composable URL validation utilities."""

import asyncio
import ipaddress
import socket

from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from urllib.parse import SplitResult, urlsplit


# Either kind of address object from the ``ipaddress`` module.
IPAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
# Custom resolver hook: called synchronously as ``resolver(host, port)`` and
# returns the host's addresses, each given as a string or an ``IPAddress``.
Resolver = Callable[[str, int | None], Sequence[IPAddress | str]]


class InvalidUrlError(ValueError):
    """Raised when URL validation rejects a URL.

    Subclasses ``ValueError``, so ``except ValueError`` handlers catch it too.
    """


@dataclass(frozen=True)
class ResolvedUrl:
    """A parsed URL and the resolved addresses used for validation.

    Instances are immutable (the dataclass is frozen).
    """

    # The URL string exactly as it was handed to the validator.
    raw: str
    # The result of ``urllib.parse.urlsplit`` applied to ``raw``.
    parsed: SplitResult
    # IP addresses the URL host resolved to. Empty when the validator was
    # configured not to resolve hosts.
    addresses: tuple[IPAddress, ...]


class UrlValidationRule(ABC):
    """A composable URL validation rule.

    Subclasses implement ``check``. A rule accepts a URL by returning
    normally and rejects it by raising ``InvalidUrlError``.
    """

    @abstractmethod
    async def check(self, url: ResolvedUrl) -> None:
        """Raise InvalidUrlError to reject the URL.

        Args:
            url: The parsed URL together with its resolved addresses.

        Raises:
            InvalidUrlError: If this rule rejects ``url``.
        """


class RequireScheme(UrlValidationRule):
    """Rule that only accepts URLs whose scheme is on an allow-list.

    Schemes are compared case-insensitively.
    """

    def __init__(self, allowed_schemes: Sequence[str]) -> None:
        """Store the lower-cased allow-list.

        Args:
            allowed_schemes: Schemes to accept, for example ``['https']``.

        Raises:
            ValueError: If ``allowed_schemes`` is empty.
        """
        if allowed_schemes:
            lowered = [scheme.lower() for scheme in allowed_schemes]
            self._allowed_schemes = frozenset(lowered)
        else:
            raise ValueError('allowed_schemes must not be empty.')

    async def check(self, url: ResolvedUrl) -> None:
        """Accept ``url`` only if its scheme is on the allow-list.

        Raises:
            InvalidUrlError: If the URL scheme is not allowed. The message
                lists the permitted schemes in sorted order.
        """
        if url.parsed.scheme.lower() in self._allowed_schemes:
            return

        allowed = ', '.join(sorted(self._allowed_schemes))
        raise InvalidUrlError(
            f'URL scheme {url.parsed.scheme!r} is not allowed. '
            f'Allowed schemes: {allowed}.'
        )


class BlockPrivateNetworks(UrlValidationRule):
    """Reject URLs resolving to non-public IP addresses.

    Hosts in ``allow_hosts`` and addresses covered by ``allow_cidrs`` are
    exempt from the non-public address check.

    When the URL carries no resolved addresses (the validator was built with
    ``resolve=False``) the host is still checked if it is an IP address
    literal, so ``http://127.0.0.1/`` is rejected either way. A DNS name that
    was never resolved cannot be inspected and is let through; enable
    resolution when this rule must cover host names.
    """

    def __init__(
        self,
        *,
        allow_hosts: Sequence[str] = (),
        allow_cidrs: Sequence[str] = (),
    ) -> None:
        """Configure the exemptions.

        Args:
            allow_hosts: Host names exempt from the check. They are compared
                after normalization with ``_normalize_host``.
            allow_cidrs: Networks in CIDR notation whose addresses are exempt.
                Parsed with ``strict=False``, so host bits set in the prefix
                are masked off rather than rejected.

        Raises:
            ValueError: If an entry of ``allow_cidrs`` is not a valid network.
        """
        self._allow_hosts = frozenset(
            _normalize_host(host) for host in allow_hosts
        )
        self._allow_networks = tuple(
            ipaddress.ip_network(cidr, strict=False) for cidr in allow_cidrs
        )

    async def check(self, url: ResolvedUrl) -> None:
        """Reject URLs that resolve to non-public addresses.

        Every address must be globally reachable (``is_global``) or fall
        inside one of the allowed networks; a single offending address is
        enough to reject the URL.

        Raises:
            InvalidUrlError: If any address is neither public nor allowed.
        """
        host = url.parsed.hostname
        if host is not None and _normalize_host(host) in self._allow_hosts:
            return

        addresses = url.addresses
        if not addresses and host is not None:
            # Nothing was resolved. Do not fail open on IP literals: parse
            # the host itself so private literals are still rejected.
            try:
                addresses = (ipaddress.ip_address(host),)
            except ValueError:
                # An unresolved DNS name: there is no address to inspect.
                addresses = ()

        for address in addresses:
            if any(address in network for network in self._allow_networks):
                continue
            if not address.is_global:
                raise InvalidUrlError(
                    f'URL host {host!r} resolves to non-public address '
                    f'{address}.'
                )
Comment on lines +99 to +106

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

If resolve=False is configured on the UrlValidator, url.addresses will be empty. In this case, BlockPrivateNetworks will silently allow any URL, even if the host is a private IP address literal (e.g., http://127.0.0.1/). To ensure robust defense-in-depth and prevent SSRF bypasses, BlockPrivateNetworks should attempt to parse the host as an IP address literal if url.addresses is empty.

        addresses = url.addresses
        if not addresses and host is not None:
            try:
                addresses = (ipaddress.ip_address(host),)
            except ValueError:
                pass

        for address in addresses:
            if any(address in network for network in self._allow_networks):
                continue
            if not address.is_global:
                raise InvalidUrlError(
                    f'URL host {host!r} resolves to non-public address '
                    f'{address}.'
                )



class UrlValidator:
    """Validate URLs by parsing, resolving, then running rules in order.

    Every failure to parse or resolve a URL is reported as
    ``InvalidUrlError``.
    """

    def __init__(
        self,
        rules: Sequence[UrlValidationRule] = (),
        *,
        resolve: bool = True,
        resolver: Resolver | None = None,
    ) -> None:
        """Configure the validator.

        Args:
            rules: Rules applied, in order, to every URL.
            resolve: When true, the URL host is resolved to IP addresses
                before the rules run and a URL without a host is rejected.
                When false, no lookup happens and rules receive an empty
                ``addresses`` tuple.
            resolver: Optional synchronous callable ``(host, port)`` that
                returns the host's addresses as strings or ``ipaddress``
                objects. When omitted, the running event loop's
                ``getaddrinfo`` is used.
        """
        self._rules = tuple(rules)
        self._resolve = resolve
        self._resolver = resolver

    async def validate(self, url: str) -> ResolvedUrl:
        """Validate a URL and return the parsed URL plus resolved addresses.

        Rules are awaited one after another; the first rule that raises
        stops validation.

        Raises:
            InvalidUrlError: If the URL cannot be parsed or resolved, or if
                a rule rejects it.
        """
        resolved = await self._build(url)
        for rule in self._rules:
            await rule.check(resolved)
        return resolved

    async def _build(self, url: str) -> ResolvedUrl:
        """Parse ``url`` and, if resolution is enabled, resolve its host."""
        try:
            parsed = urlsplit(url)
            # ``hostname`` and ``port`` are computed lazily and can raise
            # ValueError themselves (for example on a non-numeric port).
            host = parsed.hostname
            port = parsed.port
        except ValueError as exc:
            raise InvalidUrlError(f'Invalid URL {url!r}: {exc}') from exc

        addresses: tuple[IPAddress, ...] = ()
        if self._resolve:
            if host is None:
                raise InvalidUrlError(f'URL {url!r} does not include a host.')
            addresses = await self._resolve_host(host, port)

        return ResolvedUrl(raw=url, parsed=parsed, addresses=addresses)

    async def _resolve_host(
        self, host: str, port: int | None
    ) -> tuple[IPAddress, ...]:
        """Return the distinct addresses for ``host`` in resolver order.

        An IP address literal is returned as-is without any lookup.

        Raises:
            InvalidUrlError: If the lookup fails, yields no address, or
                yields something that is not a valid IP address.
        """
        try:
            return (ipaddress.ip_address(host),)
        except ValueError:
            # Not an IP literal; fall through to name resolution.
            pass

        try:
            if self._resolver is not None:
                resolved = self._resolver(host, port)
            else:
                loop = asyncio.get_running_loop()
                address_info = await loop.getaddrinfo(
                    host,
                    port,
                    type=socket.SOCK_STREAM,
                )
                # info[4] is the sockaddr tuple; its first item is the IP.
                resolved = [info[4][0] for info in address_info]
        except (OSError, UnicodeError) as exc:
            # getaddrinfo IDNA-encodes the host and raises UnicodeError, not
            # OSError, for empty or over-long labels such as 'a..b'.
            raise InvalidUrlError(
                f'Could not resolve URL host {host!r}: {exc}'
            ) from exc

        # A dict keeps first-seen order while dropping duplicate addresses.
        unique: dict[IPAddress, None] = {}
        for address in resolved:
            if isinstance(
                address, (ipaddress.IPv4Address, ipaddress.IPv6Address)
            ):
                # Already an address object: use it without re-parsing.
                unique[address] = None
                continue
            try:
                unique[ipaddress.ip_address(address)] = None
            except ValueError as exc:
                raise InvalidUrlError(
                    f'URL host {host!r} resolved to invalid address '
                    f'{address!r}.'
                ) from exc

        if not unique:
            raise InvalidUrlError(f'URL host {host!r} did not resolve.')
        return tuple(unique)


def _normalize_host(host: str) -> str:
return host.rstrip('.').lower()
176 changes: 176 additions & 0 deletions tests/utils/test_url_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""Tests for a2a.utils.url_validator."""

import ipaddress

import pytest

from a2a.utils.url_validator import (
BlockPrivateNetworks,
InvalidUrlError,
RequireScheme,
ResolvedUrl,
UrlValidationRule,
UrlValidator,
)


class RecordingRule(UrlValidationRule):
    """Test double that accepts every URL and remembers what it was given."""

    def __init__(self) -> None:
        # Every ResolvedUrl handed to ``check``, in call order.
        self.seen: list[ResolvedUrl] = []

    async def check(self, url: ResolvedUrl) -> None:
        """Record ``url`` and accept it."""
        history = self.seen
        history.append(url)


@pytest.mark.asyncio
async def test_validate_resolves_and_returns_pinned_addresses() -> None:
    """The validator returns the parsed URL with de-duplicated addresses."""

    def fake_resolver(host: str, port: int | None) -> list[str]:
        # The host and explicit port of the URL must reach the resolver.
        assert host == 'example.com'
        assert port == 443
        # The same address twice, so de-duplication is observable.
        return ['93.184.216.34', '93.184.216.34']

    recorder = RecordingRule()
    rules = [RequireScheme(['https']), recorder]
    validator = UrlValidator(rules, resolver=fake_resolver)

    outcome = await validator.validate('https://example.com:443/agent')

    assert outcome.raw == 'https://example.com:443/agent'
    split = outcome.parsed
    assert split.scheme == 'https'
    assert split.hostname == 'example.com'
    assert outcome.addresses == (ipaddress.ip_address('93.184.216.34'),)
    # The rule saw exactly the object that validate() returned.
    assert recorder.seen == [outcome]


@pytest.mark.asyncio
async def test_require_scheme_rejects_disallowed_scheme() -> None:
    """An ``http`` URL is refused when only ``https`` is permitted."""
    https_only = RequireScheme(['https'])
    # Resolution is disabled; only the scheme check is under test.
    validator = UrlValidator([https_only], resolve=False)

    with pytest.raises(InvalidUrlError, match='not allowed'):
        await validator.validate('http://example.com')


def test_require_scheme_rejects_empty_allowed_schemes() -> None:
    """Constructing RequireScheme without any scheme is a ValueError."""
    no_schemes: list[str] = []

    with pytest.raises(ValueError, match='must not be empty'):
        RequireScheme(no_schemes)


@pytest.mark.asyncio
async def test_validate_rejects_url_without_host_when_resolving() -> None:
    """With resolution enabled, a URL that has no host is rejected."""
    rules = [RequireScheme(['https'])]
    validator = UrlValidator(rules)
    hostless_url = 'https:///missing-host'

    with pytest.raises(InvalidUrlError, match='does not include a host'):
        await validator.validate(hostless_url)


@pytest.mark.asyncio
async def test_validate_rejects_invalid_url() -> None:
    """A URL that cannot be parsed surfaces as InvalidUrlError."""
    malformed = 'http://example.com:not-a-port'
    parse_only = UrlValidator(resolve=False)

    with pytest.raises(InvalidUrlError, match='Invalid URL'):
        await parse_only.validate(malformed)


@pytest.mark.asyncio
async def test_block_private_networks_rejects_loopback_address() -> None:
    """A loopback IP literal is refused by BlockPrivateNetworks."""
    rule = BlockPrivateNetworks()
    validator = UrlValidator([rule])
    loopback_url = 'http://127.0.0.1/callback'

    with pytest.raises(InvalidUrlError, match='non-public address 127.0.0.1'):
        await validator.validate(loopback_url)


@pytest.mark.asyncio
async def test_block_private_networks_allows_configured_host() -> None:
    """A host named in ``allow_hosts`` passes despite a loopback address."""

    def loopback_resolver(host: str, port: int | None) -> list[str]:
        return ['127.0.0.1']

    rule = BlockPrivateNetworks(allow_hosts=['internal.example.test'])
    validator = UrlValidator([rule], resolver=loopback_resolver)

    outcome = await validator.validate('http://internal.example.test/callback')

    expected = (ipaddress.ip_address('127.0.0.1'),)
    assert outcome.addresses == expected


@pytest.mark.asyncio
async def test_block_private_networks_allows_configured_cidr() -> None:
    """A private IP literal passes when its network is in ``allow_cidrs``."""
    rule = BlockPrivateNetworks(allow_cidrs=['10.0.0.0/8'])
    validator = UrlValidator([rule])

    outcome = await validator.validate('http://10.1.2.3/callback')

    expected = (ipaddress.ip_address('10.1.2.3'),)
    assert outcome.addresses == expected


@pytest.mark.asyncio
async def test_block_private_networks_rejects_mixed_disallowed_address() -> (
    None
):
    """One private address among public ones is enough to reject the URL."""

    def mixed_resolver(host: str, port: int | None) -> list[str]:
        # A public address first, then a private one.
        return ['93.184.216.34', '10.1.2.3']

    rule = BlockPrivateNetworks()
    validator = UrlValidator([rule], resolver=mixed_resolver)

    with pytest.raises(InvalidUrlError, match='10.1.2.3'):
        await validator.validate('http://example.com/callback')


@pytest.mark.asyncio
async def test_validate_reports_resolution_failures() -> None:
    """An OSError raised by the resolver becomes InvalidUrlError."""

    def failing_resolver(host: str, port: int | None) -> list[str]:
        raise OSError('name lookup failed')

    validator = UrlValidator(resolver=failing_resolver)
    target = 'http://example.com/callback'

    with pytest.raises(InvalidUrlError, match='Could not resolve URL host'):
        await validator.validate(target)


@pytest.mark.asyncio
async def test_validate_rejects_empty_resolution_result() -> None:
    """A resolver that yields no address makes validation fail."""

    def empty_resolver(host: str, port: int | None) -> list[str]:
        return []

    validator = UrlValidator(resolver=empty_resolver)
    target = 'http://example.com/callback'

    with pytest.raises(InvalidUrlError, match='did not resolve'):
        await validator.validate(target)


@pytest.mark.asyncio
async def test_validate_without_resolution_runs_rules_with_empty_addresses() -> (
    None
):
    """With ``resolve=False`` rules still run, but see no addresses."""
    recorder = RecordingRule()
    parse_only = UrlValidator([recorder], resolve=False)

    outcome = await parse_only.validate('custom://agent/path')

    assert outcome.parsed.scheme == 'custom'
    # No lookup was attempted, so nothing was pinned.
    assert outcome.addresses == ()
    assert recorder.seen == [outcome]
Loading