# Adding a New Guardrail Integration
You're going to create a class that checks text before it goes to the LLM and/or after the response comes back. If the text violates your rules, the guardrail blocks the request.
## How It Works
Request with guardrail:
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "gpt-4",
    "messages": [{"role": "user", "content": "How do I hack a system?"}],
    "guardrails": ["my_guardrail"]
}'
```
Your guardrail checks the input, then the output. If either violates your policy, raise an exception and the proxy returns an error instead of the LLM response.
## Build Your Guardrail
### Create Your Directory
```shell
mkdir -p litellm/proxy/guardrails/guardrail_hooks/my_guardrail
cd litellm/proxy/guardrails/guardrail_hooks/my_guardrail
```
Two files: `my_guardrail.py` (main class) and `__init__.py` (initialization).
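The resulting layout:

```
litellm/proxy/guardrails/guardrail_hooks/my_guardrail/
├── __init__.py       # initialization and registration
└── my_guardrail.py   # the guardrail class
```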
### Write the Main Class
`my_guardrail.py`:
```python
import os
from typing import List, Optional

from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.types.guardrails import PiiEntityType


class MyGuardrail(CustomGuardrail):
    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        # Fall back to environment variables when not set in the config
        self.api_key = api_key or os.getenv("MY_GUARDRAIL_API_KEY")
        self.api_base = api_base or os.getenv(
            "MY_GUARDRAIL_API_BASE", "https://api.myguardrail.com"
        )
        # Forward guardrail_name, event_hook, default_on, etc. to the base class
        super().__init__(**kwargs)

    async def apply_guardrail(
        self,
        text: str,
        language: Optional[str] = None,
        entities: Optional[List[PiiEntityType]] = None,
        request_data: Optional[dict] = None,
    ) -> str:
        # Called by LiteLLM with the text to check: raise to block, return the text to allow
        verbose_proxy_logger.debug("MyGuardrail: checking text against %s", self.api_base)
        result = await self._check_with_api(text, request_data)
        if result.get("action") == "BLOCK":
            raise HTTPException(
                status_code=400,
                detail=f"Content blocked: {result.get('reason', 'Policy violation')}",
            )
        return text

    async def _check_with_api(self, text: str, request_data: Optional[dict]) -> dict:
        # Reuse LiteLLM's shared async HTTP client instead of creating one per request
        async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        response = await async_client.post(
            f"{self.api_base}/check",
            headers=headers,
            json={"text": text},
            timeout=5,
        )
        response.raise_for_status()
        return response.json()
```
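The `/check` endpoint and its response shape are placeholders for your vendor's real moderation API. `_check_with_api` expects JSON along these lines (hypothetical; adjust the parsing to match what your service actually returns):

```python
# Hypothetical response body from POST {api_base}/check
example_response = {
    "action": "BLOCK",  # "BLOCK" rejects the request; anything else allows it
    "reason": "Prompt injection detected",  # optional; echoed in the error detail
}
```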
### Create the Init File
`__init__.py`:
```python
from typing import TYPE_CHECKING

from litellm.types.guardrails import SupportedGuardrailIntegrations

from .my_guardrail import MyGuardrail

if TYPE_CHECKING:
    from litellm.types.guardrails import Guardrail, LitellmParams


def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
    import litellm

    # Build the guardrail instance from the parsed config entry
    _my_guardrail_callback = MyGuardrail(
        api_base=litellm_params.api_base,
        api_key=litellm_params.api_key,
        guardrail_name=guardrail.get("guardrail_name", ""),
        event_hook=litellm_params.mode,
        default_on=litellm_params.default_on,
    )
    # Register the instance so LiteLLM runs it as a callback on requests
    litellm.logging_callback_manager.add_litellm_callback(_my_guardrail_callback)
    return _my_guardrail_callback


guardrail_initializer_registry = {
    SupportedGuardrailIntegrations.MY_GUARDRAIL.value: initialize_guardrail,
}

guardrail_class_registry = {
    SupportedGuardrailIntegrations.MY_GUARDRAIL.value: MyGuardrail,
}
```
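These two registries are what connect the `guardrail: my_guardrail` string in your proxy config to your code: `guardrail_initializer_registry` tells LiteLLM how to build the instance from config, and `guardrail_class_registry` maps the type name to the class itself.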
### Register Your Guardrail Type
Add your integration to the `SupportedGuardrailIntegrations` enum in `litellm/types/guardrails.py`:
```python
class SupportedGuardrailIntegrations(str, Enum):
    LAKERA = "lakera_prompt_injection"
    APORIA = "aporia"
    BEDROCK = "bedrock_guardrails"
    PRESIDIO = "presidio"
    ZSCALER_AI_GUARD = "zscaler_ai_guard"
    MY_GUARDRAIL = "my_guardrail"  # your new guardrail
```
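The enum value (`"my_guardrail"`) is the type string the registries key on; it must match the `guardrail:` field in the config below.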
## Usage
### Config File
```yaml
model_list:
  - model_name: gpt-4
    litellm_params:
      model: gpt-4
      api_key: os.environ/OPENAI_API_KEY

litellm_settings:
  guardrails:
    - guardrail_name: my_guardrail
      litellm_params:
        guardrail: my_guardrail
        mode: during_call
        api_key: os.environ/MY_GUARDRAIL_API_KEY
        api_base: https://api.myguardrail.com
```
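`mode` controls when the guardrail runs: `pre_call` checks the input before the LLM is called, `during_call` checks the input in parallel with the LLM call, and `post_call` checks the output after the response comes back.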
### Per-Request
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "gpt-4",
    "messages": [{"role": "user", "content": "Test message"}],
    "guardrails": ["my_guardrail"]
}'
```
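If you call the proxy through the OpenAI Python SDK instead of curl, the same per-request guardrail can be passed via `extra_body` (a minimal sketch; a blocked request surfaces as an HTTP 400 from the proxy):

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

try:
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": "Test message"}],
        extra_body={"guardrails": ["my_guardrail"]},  # LiteLLM-specific field
    )
    print(response.choices[0].message.content)
except openai.BadRequestError as e:
    # Raised when the guardrail blocks the request
    print(f"Request blocked by guardrail: {e}")
```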
## Testing
Add unit tests for your guardrail inside the `test_litellm/` folder.
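For example, a minimal test could monkeypatch the API call and assert on the blocking behavior. This sketch assumes `pytest-asyncio` is installed; the test names and fake responses are illustrative:

```python
import pytest
from fastapi import HTTPException

from litellm.proxy.guardrails.guardrail_hooks.my_guardrail import MyGuardrail


@pytest.mark.asyncio
async def test_blocks_flagged_text(monkeypatch):
    guardrail = MyGuardrail(api_key="fake-key")

    async def fake_check(text, request_data):
        # Stand-in for the real moderation API
        return {"action": "BLOCK", "reason": "test block"}

    monkeypatch.setattr(guardrail, "_check_with_api", fake_check)

    with pytest.raises(HTTPException):
        await guardrail.apply_guardrail(text="some flagged input")


@pytest.mark.asyncio
async def test_allows_clean_text(monkeypatch):
    guardrail = MyGuardrail(api_key="fake-key")

    async def fake_check(text, request_data):
        return {"action": "ALLOW"}

    monkeypatch.setattr(guardrail, "_check_with_api", fake_check)

    # Clean text should pass through unchanged
    assert await guardrail.apply_guardrail(text="hello") == "hello"
```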