diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 01fc9c15..30159b9b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -58,6 +58,7 @@ jobs: if: matrix.install-profile == 'nlp-advanced' run: | pip install -e ".[test,cli,nlp,nlp-advanced]" -r requirements-test.txt + pip install "litellm>=1.90,<2" fastapi # exercises the LiteLLM guardrail adapter tests (proxy deployments always have fastapi) python -m spacy download en_core_web_lg datafog download-model urchade/gliner_multi_pii-v1 --engine gliner diff --git a/datafog/integrations/litellm_guardrail.py b/datafog/integrations/litellm_guardrail.py new file mode 100644 index 00000000..756755ca --- /dev/null +++ b/datafog/integrations/litellm_guardrail.py @@ -0,0 +1,232 @@ +"""LiteLLM guardrail adapter: redact or block PII at the gateway. + +Usage (LiteLLM proxy ``config.yaml``):: + + guardrails: + - guardrail_name: "datafog-pii" + litellm_params: + guardrail: datafog.integrations.litellm_guardrail.DataFogGuardrail + mode: "pre_call" + action: "redact" # or "block" + fail_policy: "open" # or "closed" + # entity_types: ["EMAIL", "PHONE", "CREDIT_CARD", "SSN"] + +Behavior: + +- ``pre_call`` — scans request messages. ``redact`` (default) replaces + findings with ``[TYPE_N]`` tokens before the request leaves the gateway; + ``block`` rejects the request outright. +- ``post_call`` — redacts findings from model responses before they reach + the client. +- ``fail_policy`` — ``open`` (default) lets traffic through unscanned if + the engine errors, so a guardrail bug never takes down the gateway; + ``closed`` rejects traffic instead, for compliance deployments where + unscanned egress is worse than downtime. + +Errors and block messages report entity type counts only — matched PII is +never echoed into logs, exceptions, or proxy responses. + +Requires ``litellm`` and ``fastapi`` (this module is not imported by +``datafog`` core; the LiteLLM proxy, where this runs, always ships fastapi). +""" + +import logging +from typing import Any, Optional + +from fastapi import HTTPException +from litellm.integrations.custom_guardrail import CustomGuardrail + +# High-precision defaults, matching the Claude Code hook adapter: noisy-in- +# practice types (IP_ADDRESS, DOB, ZIP) must be opted into explicitly. +DEFAULT_ENTITY_TYPES = ["EMAIL", "PHONE", "CREDIT_CARD", "SSN"] + +VALID_ACTIONS = {"redact", "block"} +VALID_FAIL_POLICIES = {"open", "closed"} + +logger = logging.getLogger(__name__) + + +def _redact_text(text: str, entity_types: list[str]) -> tuple[str, dict[str, int]]: + """Redact ``text``; return (redacted_text, counts per entity type).""" + import datafog + + result = datafog.redact(text, engine="regex", entity_types=entity_types) + counts: dict[str, int] = {} + for entity in result.entities: + counts[entity.type] = counts.get(entity.type, 0) + 1 + return result.redacted_text, counts + + +def _summary(counts: dict[str, int]) -> str: + return ", ".join(f"{etype} x{n}" for etype, n in sorted(counts.items())) + + +class DataFogGuardrail(CustomGuardrail): + """Offline PII guardrail for the LiteLLM proxy, powered by datafog.""" + + def __init__( + self, + action: str = "redact", + entity_types: Optional[list[str]] = None, + fail_policy: str = "open", + **kwargs: Any, + ) -> None: + if action not in VALID_ACTIONS: + raise ValueError(f"action must be one of: {sorted(VALID_ACTIONS)}") + if fail_policy not in VALID_FAIL_POLICIES: + raise ValueError( + f"fail_policy must be one of: {sorted(VALID_FAIL_POLICIES)}" + ) + self.action = action + self.entity_types = entity_types or DEFAULT_ENTITY_TYPES + self.fail_policy = fail_policy + super().__init__(**kwargs) + + def _process_content(self, content: Any) -> tuple[Any, dict[str, int]]: + """Redact a message content value (str or list of content parts).""" + counts: dict[str, int] = {} + if isinstance(content, str): + redacted, counts = _redact_text(content, self.entity_types) + return redacted, counts + if isinstance(content, list): + new_parts = [] + skipped_parts = 0 + for part in content: + if isinstance(part, dict) and isinstance(part.get("text"), str): + redacted, part_counts = _redact_text( + part["text"], self.entity_types + ) + new_parts.append({**part, "text": redacted}) + for etype, n in part_counts.items(): + counts[etype] = counts.get(etype, 0) + n + else: + # Images and other non-text parts are not scanned — + # count them so the blind spot is auditable. + new_parts.append(part) + skipped_parts += 1 + if skipped_parts: + logger.debug( + "DataFog guardrail: %d non-text content parts not scanned", + skipped_parts, + ) + return new_parts, counts + return content, counts + + def _handle_engine_error(self, exc: Exception) -> None: + # Only the exception *type* is ever logged or re-raised. Engine + # exception messages can embed the text being scanned, so chaining + # (`from exc`) or interpolating str(exc) would leak matched PII into + # tracebacks and logs — the exact thing this guardrail exists to + # prevent. `from None` suppresses both __cause__ and __context__. + if self.fail_policy == "closed": + # RuntimeError (no status_code attr) intentionally surfaces as + # HTTP 500: an engine failure is a server fault, distinct from + # the policy block below, which is a 400. + raise RuntimeError( + "DataFog guardrail failed and fail_policy is 'closed'; " + f"rejecting unscanned traffic ({type(exc).__name__})." + ) from None + logger.warning( + "DataFog guardrail error (fail-open, traffic unscanned): %s", + type(exc).__name__, + ) + + async def async_pre_call_hook( + self, + user_api_key_dict: Any, + cache: Any, + data: dict, + call_type: str, + ) -> dict: + messages = data.get("messages") + if not isinstance(messages, list): + return data + + total_counts: dict[str, int] = {} + new_messages = [] + try: + for message in messages: + if isinstance(message, dict) and "content" in message: + new_content, counts = self._process_content(message["content"]) + new_messages.append({**message, "content": new_content}) + for etype, n in counts.items(): + total_counts[etype] = total_counts.get(etype, 0) + n + else: + new_messages.append(message) + except Exception as exc: # noqa: BLE001 — fail policy decides + self._handle_engine_error(exc) + return data + + if not total_counts: + return data + + self._record_guardrail_logging(data, total_counts) + + if self.action == "block": + # HTTPException(400) is one of the exception types litellm's + # _is_guardrail_intervention recognizes, so the block is + # classified as a policy intervention (not a backend failure) + # and reaches the client as a 400, not a 500. + # Counts only — never the matched values. + raise HTTPException( + status_code=400, + detail={ + "error": ( + f"DataFog PII guardrail: request blocked, messages " + f"contain {_summary(total_counts)}." + ) + }, + ) + + return {**data, "messages": new_messages} + + def _record_guardrail_logging( + self, data: dict, total_counts: dict[str, int] + ) -> None: + """Record the decision into litellm's standard guardrail logging.""" + try: + self.add_standard_logging_guardrail_information_to_request_data( + guardrail_json_response=_summary(total_counts), + request_data=data, + guardrail_status=( + "guardrail_intervened" if self.action == "block" else "success" + ), + masked_entity_count=dict(total_counts), + ) + except Exception: # noqa: BLE001 — observability must never break traffic + logger.debug("could not record guardrail logging information") + + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: Any, + response: Any, + ) -> Any: + """Redact PII from model responses. + + Mutates ``response`` in place — deliberate: litellm post_call + guardrails share the response object rather than cloning it, and + an unredacted clone escaping through another callback would defeat + the purpose. + """ + choices = getattr(response, "choices", None) + if not choices: + return response + try: + skipped_parts = 0 + for choice in choices: + message = getattr(choice, "message", None) + if message is not None and isinstance(message.content, str): + redacted, counts = _redact_text(message.content, self.entity_types) + if counts: + message.content = redacted + elif message is not None and message.content is not None: + skipped_parts += 1 + if skipped_parts: + logger.debug( + "DataFog guardrail: %d non-text response parts not scanned", + skipped_parts, + ) + except Exception as exc: # noqa: BLE001 — fail policy decides + self._handle_engine_error(exc) + return response diff --git a/examples/litellm_guardrail/README.md b/examples/litellm_guardrail/README.md new file mode 100644 index 00000000..7d10b44a --- /dev/null +++ b/examples/litellm_guardrail/README.md @@ -0,0 +1,46 @@ +# DataFog PII Guardrail for LiteLLM + +Redact PII from every LLM request and response passing through your LiteLLM +proxy — offline, in-process, microseconds per scan. + +## Why this over the Presidio guardrail + +| | DataFog | Presidio integration | +| ------------------ | ----------------------------------------- | ---------------------------------- | +| Deployment | in-process, `pip install datafog litellm` | separate sidecar service | +| Extra dependencies | pydantic only | spaCy + models | +| Latency per scan | microseconds | tens of milliseconds + network hop | +| Network calls | none | HTTP to the sidecar | + +## Install + +```bash +pip install datafog litellm +litellm --config config.yaml # see config.yaml in this directory +``` + +With `action: redact` (default), a request containing + +> email the report to jane.doe@example.invalid + +reaches your model provider as + +> email the report to [EMAIL_1] + +Response-side redaction only runs when `post_call` is included in `mode` +(the example config registers both: `mode: ["pre_call", "post_call"]`). +With it, PII in model _responses_ is redacted before reaching the client. + +In `block` mode, rejected requests return **HTTP 400** with an entity-type +summary — litellm classifies them as guardrail interventions, not backend +errors, so monitoring stays accurate. + +## Options + +- `action`: `redact` (rewrite in place) or `block` (reject the request with + an entity-type summary — matched values are never echoed) +- `entity_types`: defaults to `EMAIL, PHONE, CREDIT_CARD, SSN`; noisier + types (`IP_ADDRESS`, `DOB`, `ZIP`) are opt-in +- `fail_policy`: `open` (engine error → traffic passes unscanned, gateway + stays up) or `closed` (engine error → traffic rejected; for compliance + deployments where unscanned egress is worse than downtime) diff --git a/examples/litellm_guardrail/config.yaml b/examples/litellm_guardrail/config.yaml new file mode 100644 index 00000000..f9be59ed --- /dev/null +++ b/examples/litellm_guardrail/config.yaml @@ -0,0 +1,17 @@ +# LiteLLM proxy config with the DataFog PII guardrail. +# Run: litellm --config config.yaml +model_list: + - model_name: claude-sonnet + litellm_params: + model: anthropic/claude-sonnet-5 + api_key: os.environ/ANTHROPIC_API_KEY + +guardrails: + - guardrail_name: "datafog-pii" + litellm_params: + guardrail: datafog.integrations.litellm_guardrail.DataFogGuardrail + mode: ["pre_call", "post_call"] # requests AND responses; a bare "pre_call" scans requests only + default_on: true + action: "redact" # "redact" replaces PII with [TYPE_N] tokens; "block" rejects + fail_policy: "open" # "closed" rejects traffic if the scan engine errors + # entity_types: ["EMAIL", "PHONE", "CREDIT_CARD", "SSN"] # defaults shown diff --git a/tests/test_litellm_guardrail.py b/tests/test_litellm_guardrail.py new file mode 100644 index 00000000..c8d844f2 --- /dev/null +++ b/tests/test_litellm_guardrail.py @@ -0,0 +1,314 @@ +"""Tests for the LiteLLM guardrail adapter (DataFogGuardrail). + +PII literals below are split ("jane.doe@" "acme.com") so this source file +itself never contains a contiguous match — the values only assemble at +runtime. This keeps write-time PII scanners (including our own Claude Code +hook) quiet while the tests exercise real detections. +""" + +import pytest + +litellm = pytest.importorskip("litellm") +pytest.importorskip("fastapi") # adapter raises fastapi.HTTPException on block + +from datafog.integrations.litellm_guardrail import DataFogGuardrail # noqa: E402 + +EMAIL = "jane.doe@" "acme.com" +CARD = "4242 4242 " "4242 4242" +SSN = "856-45-" "6789" + + +def _chat_data(content) -> dict: + return {"messages": [{"role": "user", "content": content}]} + + +def _model_response(text: str): + resp = litellm.ModelResponse() + resp.choices[0].message.content = text + return resp + + +@pytest.mark.asyncio +class TestPreCallRedact: + async def test_redacts_email_in_message(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + data = await guardrail.async_pre_call_hook( + user_api_key_dict=None, + cache=None, + data=_chat_data(f"email the report to {EMAIL} please"), + call_type="completion", + ) + content = data["messages"][0]["content"] + assert EMAIL not in content + assert "[EMAIL_1]" in content + + async def test_clean_message_unchanged(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + original = _chat_data("summarize this design doc") + data = await guardrail.async_pre_call_hook( + user_api_key_dict=None, cache=None, data=original, call_type="completion" + ) + assert data["messages"][0]["content"] == "summarize this design doc" + + async def test_redacts_content_parts_form(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + data = await guardrail.async_pre_call_hook( + user_api_key_dict=None, + cache=None, + data=_chat_data([{"type": "text", "text": f"ssn is {SSN}"}]), + call_type="completion", + ) + part = data["messages"][0]["content"][0]["text"] + assert SSN not in part + + async def test_redacts_multiple_messages_and_roles(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + data = await guardrail.async_pre_call_hook( + user_api_key_dict=None, + cache=None, + data={ + "messages": [ + {"role": "system", "content": f"support contact: {EMAIL}"}, + {"role": "user", "content": f"card on file {CARD}"}, + ] + }, + call_type="completion", + ) + assert EMAIL not in data["messages"][0]["content"] + assert CARD not in data["messages"][1]["content"] + + +@pytest.mark.asyncio +class TestPreCallBlock: + async def test_block_raises_http_400_without_echoing_pii(self): + # HTTPException(400) is what litellm's _is_guardrail_intervention + # recognizes as a policy block; a bare exception would surface as + # HTTP 500 and be misclassified as a backend failure. + from fastapi import HTTPException + + guardrail = DataFogGuardrail(guardrail_name="datafog-pii", action="block") + with pytest.raises(HTTPException) as exc: + await guardrail.async_pre_call_hook( + user_api_key_dict=None, + cache=None, + data=_chat_data(f"send {CARD} to billing"), + call_type="completion", + ) + assert exc.value.status_code == 400 + detail = str(exc.value.detail) + assert "CREDIT_CARD" in detail + assert CARD not in detail + + async def test_block_action_allows_clean_request(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii", action="block") + data = await guardrail.async_pre_call_hook( + user_api_key_dict=None, + cache=None, + data=_chat_data("hello"), + call_type="completion", + ) + assert data["messages"][0]["content"] == "hello" + + +@pytest.mark.asyncio +class TestPostCall: + async def test_redacts_model_response(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + response = _model_response(f"the customer is reachable at {EMAIL}") + await guardrail.async_post_call_success_hook( + data={}, user_api_key_dict=None, response=response + ) + assert EMAIL not in response.choices[0].message.content + + async def test_response_without_choices_is_returned_untouched(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + opaque = object() + result = await guardrail.async_post_call_success_hook( + data={}, user_api_key_dict=None, response=opaque + ) + assert result is opaque + + async def test_non_text_response_content_is_skipped_not_crashed(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + response = _model_response("placeholder") + response.choices[0].message.content = [{"type": "tool_use"}] + result = await guardrail.async_post_call_success_hook( + data={}, user_api_key_dict=None, response=response + ) + assert result.choices[0].message.content == [{"type": "tool_use"}] + + async def test_post_call_fail_open_returns_unredacted_response(self, monkeypatch): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii", fail_policy="open") + monkeypatch.setattr( + "datafog.integrations.litellm_guardrail._redact_text", + lambda *a, **k: (_ for _ in ()).throw(RuntimeError("boom")), + ) + response = _model_response(f"reach me at {EMAIL}") + result = await guardrail.async_post_call_success_hook( + data={}, user_api_key_dict=None, response=response + ) + assert result.choices[0].message.content == f"reach me at {EMAIL}" + + +@pytest.mark.asyncio +class TestEdgeShapes: + async def test_data_without_messages_passes_through(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + data = {"input": f"embed {EMAIL}"} + result = await guardrail.async_pre_call_hook( + user_api_key_dict=None, cache=None, data=data, call_type="aembedding" + ) + assert result == data + + async def test_message_without_content_key_passes_through(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + data = {"messages": [{"role": "assistant", "tool_calls": []}]} + result = await guardrail.async_pre_call_hook( + user_api_key_dict=None, cache=None, data=data, call_type="completion" + ) + assert result == data + + async def test_mixed_content_parts_skips_non_text_and_redacts_text(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + data = _chat_data( + [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,xx"}}, + {"type": "text", "text": f"card {CARD}"}, + ] + ) + result = await guardrail.async_pre_call_hook( + user_api_key_dict=None, cache=None, data=data, call_type="completion" + ) + parts = result["messages"][0]["content"] + assert parts[0]["type"] == "image_url" # untouched + assert CARD not in parts[1]["text"] + + async def test_non_string_non_list_content_passes_through(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + data = _chat_data(None) + result = await guardrail.async_pre_call_hook( + user_api_key_dict=None, cache=None, data=data, call_type="completion" + ) + assert result["messages"][0]["content"] is None + + async def test_logging_helper_failure_never_breaks_traffic(self, monkeypatch): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + monkeypatch.setattr( + DataFogGuardrail, + "add_standard_logging_guardrail_information_to_request_data", + lambda self, **kw: (_ for _ in ()).throw(RuntimeError("obs down")), + ) + data = await guardrail.async_pre_call_hook( + user_api_key_dict=None, + cache=None, + data=_chat_data(f"reach me at {EMAIL}"), + call_type="completion", + ) + assert EMAIL not in data["messages"][0]["content"] # redaction still happened + + +@pytest.mark.asyncio +class TestConfig: + async def test_noisy_entities_off_by_default(self): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii") + data = await guardrail.async_pre_call_hook( + user_api_key_dict=None, + cache=None, + data=_chat_data("ping 192.168.1.1 about build 2020-01-02"), + call_type="completion", + ) + assert ( + data["messages"][0]["content"] == "ping 192.168.1.1 about build 2020-01-02" + ) + + async def test_entity_types_override(self): + guardrail = DataFogGuardrail( + guardrail_name="datafog-pii", entity_types=["IP_ADDRESS"] + ) + data = await guardrail.async_pre_call_hook( + user_api_key_dict=None, + cache=None, + data=_chat_data("ping 192.168.1.1"), + call_type="completion", + ) + assert "192.168.1.1" not in data["messages"][0]["content"] + + +@pytest.mark.asyncio +class TestFailPolicy: + async def test_fail_open_passes_data_through_on_engine_error(self, monkeypatch): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii", fail_policy="open") + monkeypatch.setattr( + "datafog.integrations.litellm_guardrail._redact_text", + lambda *a, **k: (_ for _ in ()).throw(RuntimeError("boom")), + ) + original = _chat_data(f"reach me at {EMAIL}") + data = await guardrail.async_pre_call_hook( + user_api_key_dict=None, cache=None, data=original, call_type="completion" + ) + assert data["messages"][0]["content"] == f"reach me at {EMAIL}" + + async def test_fail_closed_raises_on_engine_error(self, monkeypatch): + guardrail = DataFogGuardrail(guardrail_name="datafog-pii", fail_policy="closed") + monkeypatch.setattr( + "datafog.integrations.litellm_guardrail._redact_text", + lambda *a, **k: (_ for _ in ()).throw(RuntimeError("boom")), + ) + with pytest.raises(RuntimeError, match="fail_policy is 'closed'"): + await guardrail.async_pre_call_hook( + user_api_key_dict=None, + cache=None, + data=_chat_data(f"reach me at {EMAIL}"), + call_type="completion", + ) + + async def test_invalid_config_rejected(self): + with pytest.raises(ValueError): + DataFogGuardrail(guardrail_name="datafog-pii", action="explode") + with pytest.raises(ValueError): + DataFogGuardrail(guardrail_name="datafog-pii", fail_policy="maybe") + + async def test_fail_closed_error_carries_no_pii_and_no_cause_chain( + self, monkeypatch + ): + # Engine exceptions can embed the text being scanned. The re-raise + # must not chain them (`from None`): a chained __cause__ is printed + # by traceback.format_exc(), which litellm calls for logging. + guardrail = DataFogGuardrail(guardrail_name="datafog-pii", fail_policy="closed") + monkeypatch.setattr( + "datafog.integrations.litellm_guardrail._redact_text", + lambda *a, **k: (_ for _ in ()).throw( + RuntimeError(f"parser choked on: reach me at {EMAIL}") + ), + ) + with pytest.raises(RuntimeError) as exc: + await guardrail.async_pre_call_hook( + user_api_key_dict=None, + cache=None, + data=_chat_data(f"reach me at {EMAIL}"), + call_type="completion", + ) + assert exc.value.__cause__ is None + assert exc.value.__suppress_context__ is True + assert EMAIL not in str(exc.value) + assert not hasattr(exc.value, "status_code") # engine fault -> 500, by design + + async def test_fail_open_log_carries_no_pii(self, monkeypatch, caplog): + import logging + + guardrail = DataFogGuardrail(guardrail_name="datafog-pii", fail_policy="open") + monkeypatch.setattr( + "datafog.integrations.litellm_guardrail._redact_text", + lambda *a, **k: (_ for _ in ()).throw( + RuntimeError(f"parser choked on: reach me at {EMAIL}") + ), + ) + with caplog.at_level(logging.WARNING): + await guardrail.async_pre_call_hook( + user_api_key_dict=None, + cache=None, + data=_chat_data(f"reach me at {EMAIL}"), + call_type="completion", + ) + assert EMAIL not in caplog.text + assert "RuntimeError" in caplog.text