
API: Core Components

This section covers the fundamental building blocks of the SafeAgent framework.

safeagent.config

Simple configuration loader with environment variable defaults.
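
The snippet below sketches the environment-variable-defaults pattern with the standard library only; the setting names and the Config dataclass are illustrative and are not safeagent.config's actual API.

import os
from dataclasses import dataclass

@dataclass
class Config:
    # Hypothetical settings, shown only to illustrate env-var defaults.
    audit_log_path: str = os.getenv("SAFEAGENT_AUDIT_LOG_PATH", "audit")
    retention_days: int = int(os.getenv("SAFEAGENT_RETENTION_DAYS", "30"))

config = Config()
print(config.audit_log_path, config.retention_days)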

safeagent.governance

DataGovernanceError

Bases: Exception

Exception raised when governance policies are violated.

Source code in src/safeagent/governance.py
class DataGovernanceError(Exception):
    """Exception raised when governance policies are violated."""
    pass

GovernanceManager

Manages data governance policies, including encryption, auditing, retention policies, and run ID management.

Source code in src/safeagent/governance.py
class GovernanceManager:
    """
    Manages data governance policies, including encryption, auditing,
    retention policies, and run ID management.
    """

    def __init__(self, audit_log_path: str = "audit", retention_days: int = 30, audit_log_extension: str = "json"):
        self.audit_log_path = f"{audit_log_path}.{audit_log_extension}"
        self.retention_days = retention_days
        log_dir = os.path.dirname(self.audit_log_path)
        if log_dir: 
            os.makedirs(log_dir, exist_ok=True)
        open(self.audit_log_path, "a").close() 
        self.current_run_id = None

    def start_new_run(self) -> str:
        """Generates a new unique ID for a single, complete run of an orchestrator."""
        self.current_run_id = str(uuid.uuid4())
        return self.current_run_id

    def get_current_run_id(self) -> str:
        """Returns the ID for the current run, creating one if it doesn't exist."""
        if not self.current_run_id:
            return self.start_new_run()
        return self.current_run_id

    def encrypt(self, plaintext: str) -> str:
        """Encrypt sensitive data before storage."""
        return fernet.encrypt(plaintext.encode()).decode()

    def decrypt(self, token: str) -> str:
        """Decrypt sensitive data when needed."""
        return fernet.decrypt(token.encode()).decode()

    def audit(self, user_id: str, action: str, resource: str, metadata: Dict[str, Any] = None) -> None:
        """Write an audit log entry for data actions, including the current run_id."""
        entry = {
            "timestamp": time.time(),
            "run_id": self.get_current_run_id(), 
            "user_id": user_id,
            "action": action,
            "resource": resource,
            "metadata": metadata or {}
        }
        with open(self.audit_log_path, "a") as f:
            f.write(json.dumps(entry) + "\n")

    def tag_lineage(self, record: Dict[str, Any], source: str) -> Dict[str, Any]:
        """Attach lineage metadata to a record."""
        if "_lineage" not in record:
            record["_lineage"] = []
        record["_lineage"].append({
            "timestamp": time.time(),
            "source": source
        })
        return record

    def purge_old_logs(self) -> None:
        """Purge audit log entries older than retention period."""
        cutoff = time.time() - self.retention_days * 86400
        retained = []
        try:
            with open(self.audit_log_path, "r") as f:
                for line in f:
                    try:
                        entry = json.loads(line)
                        if entry.get("timestamp", 0) >= cutoff:
                            retained.append(line)
                    except json.JSONDecodeError:
                        logging.warning(f"Skipping malformed line in audit log: {line.strip()}")
                        continue 
        except FileNotFoundError:
            logging.info(f"Audit log file not found at {self.audit_log_path} during purge. No purging needed.")
            return

        with open(self.audit_log_path, "w") as f:
            f.writelines(retained)
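
Example usage (illustrative): a minimal sketch assuming GovernanceManager is importable as below and that the module-level fernet key used by encrypt()/decrypt() is configured.

from safeagent.governance import GovernanceManager

gov = GovernanceManager(audit_log_path="audit/run", retention_days=7)
run_id = gov.start_new_run()                     # one ID per orchestrator run

token = gov.encrypt("user@example.com")          # ciphertext safe to persist
gov.audit(user_id="alice", action="store_email", resource="crm", metadata={"run": run_id})

record = gov.tag_lineage({"id": "doc-1"}, source="ingest")
print(record["_lineage"][-1]["source"])          # -> "ingest"

gov.purge_old_logs()                             # drop audit entries past retention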

audit(user_id, action, resource, metadata=None)

Write an audit log entry for data actions, including the current run_id.

Source code in src/safeagent/governance.py
def audit(self, user_id: str, action: str, resource: str, metadata: Dict[str, Any] = None) -> None:
    """Write an audit log entry for data actions, including the current run_id."""
    entry = {
        "timestamp": time.time(),
        "run_id": self.get_current_run_id(), 
        "user_id": user_id,
        "action": action,
        "resource": resource,
        "metadata": metadata or {}
    }
    with open(self.audit_log_path, "a") as f:
        f.write(json.dumps(entry) + "\n")

decrypt(token)

Decrypt sensitive data when needed.

Source code in src/safeagent/governance.py
def decrypt(self, token: str) -> str:
    """Decrypt sensitive data when needed."""
    return fernet.decrypt(token.encode()).decode()

encrypt(plaintext)

Encrypt sensitive data before storage.

Source code in src/safeagent/governance.py
def encrypt(self, plaintext: str) -> str:
    """Encrypt sensitive data before storage."""
    return fernet.encrypt(plaintext.encode()).decode()

get_current_run_id()

Returns the ID for the current run, creating one if it doesn't exist.

Source code in src/safeagent/governance.py
def get_current_run_id(self) -> str:
    """Returns the ID for the current run, creating one if it doesn't exist."""
    if not self.current_run_id:
        return self.start_new_run()
    return self.current_run_id

purge_old_logs()

Purge audit log entries older than retention period.

Source code in src/safeagent/governance.py
def purge_old_logs(self) -> None:
    """Purge audit log entries older than retention period."""
    cutoff = time.time() - self.retention_days * 86400
    retained = []
    try:
        with open(self.audit_log_path, "r") as f:
            for line in f:
                try:
                    entry = json.loads(line)
                    if entry.get("timestamp", 0) >= cutoff:
                        retained.append(line)
                except json.JSONDecodeError:
                    logging.warning(f"Skipping malformed line in audit log: {line.strip()}")
                    continue 
    except FileNotFoundError:
        logging.info(f"Audit log file not found at {self.audit_log_path} during purge. No purging needed.")
        return

    with open(self.audit_log_path, "w") as f:
        f.writelines(retained)

start_new_run()

Generates a new unique ID for a single, complete run of an orchestrator.

Source code in src/safeagent/governance.py
def start_new_run(self) -> str:
    """Generates a new unique ID for a single, complete run of an orchestrator."""
    self.current_run_id = str(uuid.uuid4())
    return self.current_run_id

tag_lineage(record, source)

Attach lineage metadata to a record.

Source code in src/safeagent/governance.py
def tag_lineage(self, record: Dict[str, Any], source: str) -> Dict[str, Any]:
    """Attach lineage metadata to a record."""
    if "_lineage" not in record:
        record["_lineage"] = []
    record["_lineage"].append({
        "timestamp": time.time(),
        "source": source
    })
    return record

safeagent.llm_client

FrameworkError

Bases: Exception

Custom exception for framework-related errors.

Source code in src/safeagent/llm_client.py
class FrameworkError(Exception):
    """Custom exception for framework-related errors."""
    pass

LLMClient

Thin wrapper around any LLM provider with retries, error handling, and structured JSON logging.

Source code in src/safeagent/llm_client.py
class LLMClient:
    """Thin wrapper around any LLM provider with retries, error handling, and structured JSON logging."""

    def __init__(self, provider: str, api_key: str, model: str, base_url: str = None):
        """
        Initialize the LLM client.

        Args:
            provider (str): Name of the provider (e.g., 'openai', 'anthropic').
            api_key (str): API key or token for authentication.
            model (str): Model identifier (e.g., 'gpt-4', 'claude-3-opus').
            base_url (str, optional): Custom endpoint URL; defaults to provider-specific default.
        """
        self.provider = provider
        self.api_key = api_key
        self.model = model
        self.base_url = base_url or self._default_url()
        if requests is not None:
            self.session = requests.Session()
        else:
            class _DummySession:
                def __init__(self):
                    self.headers = {}

                def post(self, *_, **__):
                    raise FrameworkError("requests package is required for HTTP calls")

            self.session = _DummySession()
        self.session.headers.update({
            "Content-Type": "application/json"
        })
        if self.provider != "gemini":
            self.session.headers["Authorization"] = f"Bearer {self.api_key}"
        self.gov = GovernanceManager()

    def _default_url(self) -> str:
        """Return default endpoint URL based on provider."""
        if self.provider == "openai":
            return "https://api.openai.com/v1/chat/completions"
        if self.provider == "anthropic":
            return "https://api.anthropic.com/v1/complete"
        if self.provider == "gemini":
            return f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent?key={self.api_key}"
        raise FrameworkError(f"No default URL configured for provider '{self.provider}'")

    def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> Dict:
        """
        Call the underlying LLM API, with up to 3 retries.

        Args:
            prompt (str): The textual prompt to send to the model.
            max_tokens (int): Maximum number of tokens in the response.
            temperature (float): Sampling temperature.

        Returns:
            Dict: A dictionary containing keys 'text', 'usage', and 'metadata'.

        Raises:
            FrameworkError: If the API fails after retries.
        """
        # Encrypt the prompt before logging
        encrypted_prompt = self.gov.encrypt(prompt)
        self.gov.audit(user_id="system", action="encrypt_prompt", resource="llm_client", metadata={"prompt_enc": encrypted_prompt[:50]})
        payload = self._build_payload(prompt, max_tokens, temperature)

        # Log start of LLM call and audit
        req_id = get_request_id()
        log_entry_start = {
            "event": "llm_call_start",
            "provider": self.provider,
            "model": self.model,
            "prompt_snippet": prompt[:100],
            "request_id": req_id,
            "timestamp": time.time(),
        }
        logging.info(json.dumps(log_entry_start))
        self.gov.audit(
            user_id="system",
            action="llm_call_start",
            resource=self.provider,
            metadata={"model": self.model, "request_id": req_id},
        )

        # Attempt with exponential backoff
        for attempt in range(3):
            try:
                resp = self.session.post(self.base_url, json=payload, timeout=30)
                if resp.status_code != 200:
                    raise FrameworkError(f"LLM returned status {resp.status_code}: {resp.text}")
                data = resp.json()
                text, usage = self._parse_response(data)

                # Log end of LLM call and audit
                log_entry_end = {
                    "event": "llm_call_end",
                    "provider": self.provider,
                    "model": self.model,
                    "usage": usage,
                    "request_id": req_id,
                    "timestamp": time.time(),
                }
                logging.info(json.dumps(log_entry_end))
                self.gov.audit(
                    user_id="system",
                    action="llm_call_end",
                    resource=self.provider,
                    metadata={"model": self.model, "usage": usage, "request_id": req_id},
                )

                return {"text": text, "usage": usage, "metadata": {"provider": self.provider, "model": self.model}}

            except Exception as e:
                wait = 2 ** attempt
                logging.warning(f"LLM call failed (attempt {attempt + 1}): {e}. Retrying in {wait}s")
                time.sleep(wait)

        raise FrameworkError("LLM generate() failed after 3 attempts")

    def _build_payload(self, prompt: str, max_tokens: int, temperature: float) -> Dict:
        """Construct provider-specific payload for the API call."""
        if self.provider == "openai":
            return {
                "model": self.model,
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": max_tokens,
                "temperature": temperature
            }
        if self.provider == "anthropic":
            return {
                "model": self.model,
                "prompt": prompt,
                "max_tokens_to_sample": max_tokens,
                "temperature": temperature
            }
        if self.provider == "gemini":
            return {
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {"maxOutputTokens": max_tokens, "temperature": temperature}
            }
        raise FrameworkError(f"Payload builder not implemented for '{self.provider}'")

    def _parse_response(self, data: Dict) -> (str, Dict):
        """Extract generated text and usage info from API response."""
        if self.provider == "openai":
            choice = data.get("choices", [])[0]
            return choice.get("message", {}).get("content", ""), data.get("usage", {})
        if self.provider == "anthropic":
            return data.get("completion", ""), {
                "prompt_tokens": data.get("prompt_tokens"),
                "completion_tokens": data.get("completion_tokens")
            }
        if self.provider == "gemini":
            text = (
                data.get("candidates", [{}])[0]
                .get("content", {})
                .get("parts", [{}])[0]
                .get("text", "")
            )
            usage = data.get("usageMetadata", {})
            return text, {
                "prompt_tokens": usage.get("promptTokenCount"),
                "completion_tokens": usage.get("candidatesTokenCount"),
            }
        raise FrameworkError(f"Response parser not implemented for '{self.provider}'")

__init__(provider, api_key, model, base_url=None)

Initialize the LLM client.

Parameters:

    provider (str): Name of the provider (e.g., 'openai', 'anthropic'). Required.
    api_key (str): API key or token for authentication. Required.
    model (str): Model identifier (e.g., 'gpt-4', 'claude-3-opus'). Required.
    base_url (str, optional): Custom endpoint URL; defaults to a provider-specific default. Default: None.

Source code in src/safeagent/llm_client.py
def __init__(self, provider: str, api_key: str, model: str, base_url: str = None):
    """
    Initialize the LLM client.

    Args:
        provider (str): Name of the provider (e.g., 'openai', 'anthropic').
        api_key (str): API key or token for authentication.
        model (str): Model identifier (e.g., 'gpt-4', 'claude-3-opus').
        base_url (str, optional): Custom endpoint URL; defaults to provider-specific default.
    """
    self.provider = provider
    self.api_key = api_key
    self.model = model
    self.base_url = base_url or self._default_url()
    if requests is not None:
        self.session = requests.Session()
    else:
        class _DummySession:
            def __init__(self):
                self.headers = {}

            def post(self, *_, **__):
                raise FrameworkError("requests package is required for HTTP calls")

        self.session = _DummySession()
    self.session.headers.update({
        "Content-Type": "application/json"
    })
    if self.provider != "gemini":
        self.session.headers["Authorization"] = f"Bearer {self.api_key}"
    self.gov = GovernanceManager()

generate(prompt, max_tokens=512, temperature=0.7)

Call the underlying LLM API, with up to 3 retries.

Parameters:

    prompt (str): The textual prompt to send to the model. Required.
    max_tokens (int): Maximum number of tokens in the response. Default: 512.
    temperature (float): Sampling temperature. Default: 0.7.

Returns:

    Dict: A dictionary containing keys 'text', 'usage', and 'metadata'.

Raises:

    FrameworkError: If the API fails after retries.

Source code in src/safeagent/llm_client.py
def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> Dict:
    """
    Call the underlying LLM API, with up to 3 retries.

    Args:
        prompt (str): The textual prompt to send to the model.
        max_tokens (int): Maximum number of tokens in the response.
        temperature (float): Sampling temperature.

    Returns:
        Dict: A dictionary containing keys 'text', 'usage', and 'metadata'.

    Raises:
        FrameworkError: If the API fails after retries.
    """
    # Encrypt the prompt before logging
    encrypted_prompt = self.gov.encrypt(prompt)
    self.gov.audit(user_id="system", action="encrypt_prompt", resource="llm_client", metadata={"prompt_enc": encrypted_prompt[:50]})
    payload = self._build_payload(prompt, max_tokens, temperature)

    # Log start of LLM call and audit
    req_id = get_request_id()
    log_entry_start = {
        "event": "llm_call_start",
        "provider": self.provider,
        "model": self.model,
        "prompt_snippet": prompt[:100],
        "request_id": req_id,
        "timestamp": time.time(),
    }
    logging.info(json.dumps(log_entry_start))
    self.gov.audit(
        user_id="system",
        action="llm_call_start",
        resource=self.provider,
        metadata={"model": self.model, "request_id": req_id},
    )

    # Attempt with exponential backoff
    for attempt in range(3):
        try:
            resp = self.session.post(self.base_url, json=payload, timeout=30)
            if resp.status_code != 200:
                raise FrameworkError(f"LLM returned status {resp.status_code}: {resp.text}")
            data = resp.json()
            text, usage = self._parse_response(data)

            # Log end of LLM call and audit
            log_entry_end = {
                "event": "llm_call_end",
                "provider": self.provider,
                "model": self.model,
                "usage": usage,
                "request_id": req_id,
                "timestamp": time.time(),
            }
            logging.info(json.dumps(log_entry_end))
            self.gov.audit(
                user_id="system",
                action="llm_call_end",
                resource=self.provider,
                metadata={"model": self.model, "usage": usage, "request_id": req_id},
            )

            return {"text": text, "usage": usage, "metadata": {"provider": self.provider, "model": self.model}}

        except Exception as e:
            wait = 2 ** attempt
            logging.warning(f"LLM call failed (attempt {attempt + 1}): {e}. Retrying in {wait}s")
            time.sleep(wait)

    raise FrameworkError("LLM generate() failed after 3 attempts")

safeagent.memory_manager

MemoryManager

Minimal key-value memory store. Supports 'inmemory' or 'redis' backends and logs each read/write. Optionally, it can summarize the entire memory via an LLM.

Source code in src/safeagent/memory_manager.py
class MemoryManager:
    """
    Minimal key-value memory store.
    Supports 'inmemory' or 'redis' backends and logs each read/write.
    Optionally, can summarize entire memory via an LLM.
    """

    def __init__(self, backend: str = "inmemory", redis_url: str = None):
        """
        backend: "inmemory" (default) or "redis".
        redis_url: e.g., "redis://localhost:6379" if backend="redis".
        """
        global _redis
        self.backend = backend

        if self.backend == "redis":
            if _redis is None:
                try:
                    import redis
                    _redis = redis
                except ModuleNotFoundError:
                    logging.error("Redis backend selected, but 'redis' package not found. Falling back to in-memory.")
                    self.backend = "inmemory" 
                    self.store = {}
                    return

            if _redis: 
                self.client = _redis.from_url(redis_url)
                try:
                    self.client.ping()
                    logging.info("Successfully connected to Redis.")
                except Exception as e:
                    logging.error(f"Failed to connect to Redis at {redis_url}: {e}. Falling back to in-memory.")
                    self.backend = "inmemory"
                    self.store = {}
            else:
                logging.error("Redis package not available. Falling back to in-memory.")
                self.backend = "inmemory"
                self.store = {}

        if self.backend == "inmemory":
            self.store = {} 

    def save(self, user_id: str, key: str, value: str) -> None:
        """Saves value under (user_id, key)."""
        if self.backend == "redis":
            self.client.hset(user_id, key, value)
        else:
            self.store.setdefault(user_id, {})[key] = value

        logging.info(json.dumps({
            "event": "memory_save",
            "user_id": user_id,
            "key": key,
            "request_id": get_request_id(),
            "timestamp": time.time(),
        }))

    def load(self, user_id: str, key: str) -> str:
        """Loads value for (user_id, key). Returns empty string if missing."""
        if self.backend == "redis":
            raw = self.client.hget(user_id, key)
            if isinstance(raw, bytes):
                value = raw.decode("utf-8")
            elif raw is None:
                value = ""
            else:
                value = str(raw)
        else:
            value = self.store.get(user_id, {}).get(key, "")

        logging.info(json.dumps({
            "event": "memory_load",
            "user_id": user_id,
            "key": key,
            "request_id": get_request_id(),
            "timestamp": time.time(),
        }))
        return value

    def summarize(self, user_id: str, embed_fn, llm_client, max_tokens: int = 256) -> str:
        """
        Reads all entries for user_id, concatenates them, and calls LLM to generate a summary.
        Stores the summary under key="summary" and returns it.
        """
        if self.backend == "redis":
            # Ensure proper handling if client failed to initialize or connection dropped
            try:
                all_vals = [v.decode("utf-8") for v in self.client.hvals(user_id)]
            except Exception as e:
                logging.warning(f"Could not retrieve from Redis during summarize: {e}. Using empty history.")
                all_vals = []
        else:
            all_vals = list(self.store.get(user_id, {}).values())

        full_text = "\n".join(all_vals)
        if not full_text:
            return ""

        summary_prompt = f"Summarize the following conversation history:\n\n{full_text}"
        resp = llm_client.generate(summary_prompt, max_tokens=max_tokens)
        summary = resp["text"]

        # Save summary back to memory
        self.save(user_id, "summary", summary)
        return summary
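
Example usage (illustrative): a minimal sketch with the default in-memory backend; summarize() expects an LLMClient-compatible object, and the embed_fn argument is accepted but unused by the implementation shown above.

from safeagent.memory_manager import MemoryManager

memory = MemoryManager()                          # in-memory dict backend
memory.save("user-42", "turn_1", "Hi, I need help with my invoice.")
memory.save("user-42", "turn_2", "It was charged twice last month.")

print(memory.load("user-42", "turn_1"))           # -> "Hi, I need help with my invoice."
print(memory.load("user-42", "missing"))          # -> "" when the key is absent

# With an LLMClient instance available:
# summary = memory.summarize("user-42", embed_fn=None, llm_client=client)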

__init__(backend='inmemory', redis_url=None)

backend: "inmemory" (default) or "redis". redis_url: e.g., "redis://localhost:6379" if backend="redis".

Source code in src/safeagent/memory_manager.py
def __init__(self, backend: str = "inmemory", redis_url: str = None):
    """
    backend: "inmemory" (default) or "redis".
    redis_url: e.g., "redis://localhost:6379" if backend="redis".
    """
    global _redis
    self.backend = backend

    if self.backend == "redis":
        if _redis is None:
            try:
                import redis
                _redis = redis
            except ModuleNotFoundError:
                logging.error("Redis backend selected, but 'redis' package not found. Falling back to in-memory.")
                self.backend = "inmemory" 
                self.store = {}
                return

        if _redis: 
            self.client = _redis.from_url(redis_url)
            try:
                self.client.ping()
                logging.info("Successfully connected to Redis.")
            except Exception as e:
                logging.error(f"Failed to connect to Redis at {redis_url}: {e}. Falling back to in-memory.")
                self.backend = "inmemory"
                self.store = {}
        else:
            logging.error("Redis package not available. Falling back to in-memory.")
            self.backend = "inmemory"
            self.store = {}

    if self.backend == "inmemory":
        self.store = {} 

load(user_id, key)

Loads value for (user_id, key). Returns empty string if missing.

Source code in src/safeagent/memory_manager.py
def load(self, user_id: str, key: str) -> str:
    """Loads value for (user_id, key). Returns empty string if missing."""
    if self.backend == "redis":
        raw = self.client.hget(user_id, key)
        if isinstance(raw, bytes):
            value = raw.decode("utf-8")
        elif raw is None:
            value = ""
        else:
            value = str(raw)
    else:
        value = self.store.get(user_id, {}).get(key, "")

    logging.info(json.dumps({
        "event": "memory_load",
        "user_id": user_id,
        "key": key,
        "request_id": get_request_id(),
        "timestamp": time.time(),
    }))
    return value

save(user_id, key, value)

Saves value under (user_id, key).

Source code in src/safeagent/memory_manager.py
def save(self, user_id: str, key: str, value: str) -> None:
    """Saves value under (user_id, key)."""
    if self.backend == "redis":
        self.client.hset(user_id, key, value)
    else:
        self.store.setdefault(user_id, {})[key] = value

    logging.info(json.dumps({
        "event": "memory_save",
        "user_id": user_id,
        "key": key,
        "request_id": get_request_id(),
        "timestamp": time.time(),
    }))

summarize(user_id, embed_fn, llm_client, max_tokens=256)

Reads all entries for user_id, concatenates them, and calls LLM to generate a summary. Stores the summary under key="summary" and returns it.

Source code in src/safeagent/memory_manager.py
def summarize(self, user_id: str, embed_fn, llm_client, max_tokens: int = 256) -> str:
    """
    Reads all entries for user_id, concatenates them, and calls LLM to generate a summary.
    Stores the summary under key="summary" and returns it.
    """
    if self.backend == "redis":
        # Ensure proper handling if client failed to initialize or connection dropped
        try:
            all_vals = [v.decode("utf-8") for v in self.client.hvals(user_id)]
        except Exception as e:
            logging.warning(f"Could not retrieve from Redis during summarize: {e}. Using empty history.")
            all_vals = []
    else:
        all_vals = list(self.store.get(user_id, {}).values())

    full_text = "\n".join(all_vals)
    if not full_text:
        return ""

    summary_prompt = f"Summarize the following conversation history:\n\n{full_text}"
    resp = llm_client.generate(summary_prompt, max_tokens=max_tokens)
    summary = resp["text"]

    # Save summary back to memory
    self.save(user_id, "summary", summary)
    return summary

safeagent.prompt_renderer

PromptRenderer

Jinja2-based templating engine with structured logging and lineage tagging.

Source code in src/safeagent/prompt_renderer.py
class PromptRenderer:
    """Jinja2-based templating engine with structured logging and lineage tagging."""

    def __init__(self, template_dir: Path):
        """
        Args:
            template_dir (Path): Path to the directory containing Jinja2 templates.
        """
        self.env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(str(template_dir)),
            autoescape=False
        )
        self.gov = GovernanceManager()

    def render(self, template_name: str, **context) -> str:
        """
        Render a Jinja2 template with provided context, logging the event and tagging lineage.

        Args:
            template_name (str): Filename of the template (e.g., 'qa_prompt.j2').
            **context: Key-value pairs to pass into the template rendering.

        Returns:
            str: The rendered template as a string.
        """
        # Audit prompt render
        lineage_metadata = {"template": template_name, "context_keys": list(context.keys())}
        self.gov.audit(user_id="system", action="prompt_render", resource=template_name, metadata=lineage_metadata)

        template = self.env.get_template(template_name)
        rendered = template.render(**context)
        log_entry = {
            "event": "prompt_render",
            "template": template_name,
            "context_keys": list(context.keys()),
            "output_length": len(rendered),
            "timestamp": time.time()
        }
        logging.info(json.dumps(log_entry))
        return rendered
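
Example usage (illustrative): a minimal sketch assuming a templates/ directory containing qa_prompt.j2; both names are examples.

from pathlib import Path
from safeagent.prompt_renderer import PromptRenderer

renderer = PromptRenderer(template_dir=Path("templates"))
# Assuming templates/qa_prompt.j2 contains: "Answer the question: {{ question }}"
prompt = renderer.render("qa_prompt.j2", question="What does the retention policy cover?")
print(prompt)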

__init__(template_dir)

Parameters:

    template_dir (Path): Path to the directory containing Jinja2 templates. Required.

Source code in src/safeagent/prompt_renderer.py
def __init__(self, template_dir: Path):
    """
    Args:
        template_dir (Path): Path to the directory containing Jinja2 templates.
    """
    self.env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(str(template_dir)),
        autoescape=False
    )
    self.gov = GovernanceManager()

render(template_name, **context)

Render a Jinja2 template with provided context, logging the event and tagging lineage.

Parameters:

    template_name (str): Filename of the template (e.g., 'qa_prompt.j2'). Required.
    **context: Key-value pairs to pass into the template rendering.

Returns:

    str: The rendered template as a string.

Source code in src/safeagent/prompt_renderer.py
def render(self, template_name: str, **context) -> str:
    """
    Render a Jinja2 template with provided context, logging the event and tagging lineage.

    Args:
        template_name (str): Filename of the template (e.g., 'qa_prompt.j2').
        **context: Key-value pairs to pass into the template rendering.

    Returns:
        str: The rendered template as a string.
    """
    # Audit prompt render
    lineage_metadata = {"template": template_name, "context_keys": list(context.keys())}
    self.gov.audit(user_id="system", action="prompt_render", resource=template_name, metadata=lineage_metadata)

    template = self.env.get_template(template_name)
    rendered = template.render(**context)
    log_entry = {
        "event": "prompt_render",
        "template": template_name,
        "context_keys": list(context.keys()),
        "output_length": len(rendered),
        "timestamp": time.time()
    }
    logging.info(json.dumps(log_entry))
    return rendered

safeagent.embeddings

EmbeddingError

Bases: Exception

Custom exception for embedding-related failures.

Source code in src/safeagent/embeddings.py
class EmbeddingError(Exception):
    """Custom exception for embedding-related failures."""
    pass

gemini_embed(text, api_key, model='embedding-001')

Generates embeddings using the Google Gemini API.

This function formats the request for the embedding model, passing the API key as a URL query parameter and avoiding conflicting headers.

Parameters:

    text (str): The text to embed. Required.
    api_key (str): The Google API key. Required.
    model (str): The embedding model to use. Default: 'embedding-001'.

Returns:

    Optional[List[float]]: A list of floats representing the embedding, or None on failure.

Raises:

    EmbeddingError: If the API call fails after retries.

Source code in src/safeagent/embeddings.py
def gemini_embed(text: str, api_key: str, model: str = "embedding-001") -> Optional[List[float]]:
    """
    Generates embeddings using the Google Gemini API.

    This function now correctly formats the request for the embedding model,
    passing the API key as a URL parameter and avoiding conflicting headers.

    Args:
        text (str): The text to embed.
        api_key (str): The Google API key.
        model (str): The embedding model to use.

    Returns:
        A list of floats representing the embedding, or None on failure.

    Raises:
        EmbeddingError: If the API call fails after retries.
    """
    if not api_key:
        raise EmbeddingError("Gemini API key is required for embeddings.")

    url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:embedContent?key={api_key}"

    payload = {"model": f"models/{model}", "content": {"parts": [{"text": text}]}}

    headers = {"Content-Type": "application/json"}

    try:
        resp = _session.post(url, json=payload, headers=headers, timeout=30)

        if resp.status_code != 200:
            logging.error(f"Gemini embed API request failed with status {resp.status_code}: {resp.text}")
            raise EmbeddingError(f"Gemini embed failed: {resp.text}")

        data = resp.json()
        embedding = data.get("embedding", {}).get("values")

        if not embedding:
            raise EmbeddingError("Embedding not found in Gemini API response.")

        return embedding

    except requests.exceptions.RequestException as e:
        logging.error(f"A network error occurred while calling Gemini embed API: {e}")
        raise EmbeddingError(f"Network error during embedding: {e}") from e
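
Example usage (illustrative): a minimal sketch that reads the API key from the environment and uses the default embedding-001 model.

import os
from safeagent.embeddings import gemini_embed, EmbeddingError

try:
    vector = gemini_embed("SafeAgent keeps an encrypted audit trail.", api_key=os.environ["GEMINI_API_KEY"])
    print(len(vector))           # embedding dimensionality reported by the API
except EmbeddingError as exc:
    print(f"Embedding failed: {exc}")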

safeagent.retriever

BaseRetriever

Base interface for retrieval. Requires implementing index and query.

Source code in src/safeagent/retriever.py
class BaseRetriever:
    """Base interface for retrieval. Requires implementing index and query."""
    def index(self, embeddings: List[Any], metadata: List[Dict[str, Any]]) -> None:
        raise NotImplementedError

    def query(self, query_text: str, top_k: int = 5) -> List[Dict[str, Any]]:
        raise NotImplementedError

GraphRetriever

Bases: BaseRetriever

Neo4j-backed GraphRAG retriever using GDS k-NN, with governance integration.

Source code in src/safeagent/retriever.py
class GraphRetriever(BaseRetriever):
    """Neo4j-backed GraphRAG retriever using GDS k-NN, with governance integration."""

    def __init__(self, neo4j_uri: str, user: str, password: str, gds_graph_name: str, embed_model_fn):
        """Create the retriever. If neo4j_uri is falsy, the retriever is disabled."""
        self.driver = None
        self.gov = GovernanceManager()
        self.embed = embed_model_fn
        self.gds_graph = gds_graph_name

        if not neo4j_uri:
            logging.info("GraphRetriever is disabled because no neo4j_uri was provided.")
            return

        try:
            from neo4j import GraphDatabase, exceptions
            self.driver = GraphDatabase.driver(neo4j_uri, auth=(user, password))
            # Test the connection to fail fast
            with self.driver.session() as session:
                session.run("RETURN 1")
            logging.info("Successfully connected to Neo4j.")
        except ImportError:
            logging.warning("The 'neo4j' library is not installed. GraphRetriever will be disabled.")
            self.driver = None
        except exceptions.ServiceUnavailable:
            logging.warning(f"Could not connect to Neo4j at '{neo4j_uri}'. GraphRetriever is disabled.")
            self.driver = None
        except Exception as e:
            logging.warning(f"An unexpected error occurred while connecting to Neo4j. GraphRetriever is disabled. Error: {e}")
            self.driver = None


    def index(self, embeddings: List[List[float]], metadata: List[Dict[str, Any]]):
        """
        Ingest each document as a node with a 'vector' property and 'metadata' (with lineage tagging).
        """
        if not self.driver:
            return 

        self.gov.audit(user_id="system", action="graph_index", resource="neo4j", metadata={"count": len(embeddings)})
        with self.driver.session() as session:
            for vec, meta in zip(embeddings, metadata):
                tagged_meta = self.gov.tag_lineage(meta.copy(), source="graph_index")
                session.run(
                    "MERGE (d:Document {id: $id}) "
                    "SET d.vector = $vector, d.metadata = $meta",
                    id=meta["id"], vector=vec, meta=tagged_meta
                )
        log_entry = {
            "event": "graph_index",
            "count": len(embeddings),
            "timestamp": time.time()
        }
        logging.info(json.dumps(log_entry))

    def query(self, query_text: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Compute embedding for query_text, run GDS K-NN, and return nearest documents (with lineage tagging).
        """
        if not self.driver:
            return []

        # Encrypt and audit query
        encrypted_query = self.gov.encrypt(query_text)
        self.gov.audit(user_id="system", action="graph_query", resource="neo4j", metadata={"query_enc": encrypted_query[:50], "top_k": top_k})

        vec = self.embed(query_text)
        cypher = f"""
            CALL gds.knn.stream(
                '{self.gds_graph}',
                {{
                    topK: $k,
                    nodeWeightProperty: 'vector',
                    queryVector: $vector
                }}
            ) YIELD nodeId, similarity
            RETURN gds.util.asNode(nodeId).id AS id, similarity
        """
        results = []
        try:
            with self.driver.session() as session:
                for record in session.run(cypher, vector=vec, k=top_k):
                    node_id = record["id"]
                    score = record["similarity"]
                    meta_record = session.run(
                        "MATCH (d:Document {id: $id}) RETURN d.metadata AS meta", id=node_id
                    ).single()
                    if meta_record:
                        meta = meta_record["meta"]
                        tagged_meta = self.gov.tag_lineage(meta.copy(), source="graph_query")
                        results.append({"id": node_id, "score": score, "metadata": tagged_meta})
        except Exception as e:
            logging.error(f"Error querying Neo4j GDS: {e}")
            return []

        log_entry = {
            "event": "graph_query",
            "top_k": top_k,
            "timestamp": time.time()
        }
        logging.info(json.dumps(log_entry))
        return results
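
Example usage (illustrative): a minimal sketch in which the Neo4j URI, credentials, GDS graph name, and embedding function are all placeholders; with a falsy URI the retriever disables itself and query() returns an empty list.

from safeagent.retriever import GraphRetriever
from safeagent.embeddings import gemini_embed

retriever = GraphRetriever(
    neo4j_uri="bolt://localhost:7687",    # assumes a local Neo4j with GDS installed
    user="neo4j",
    password="password",
    gds_graph_name="documents",
    embed_model_fn=lambda text: gemini_embed(text, api_key="..."),
)
for hit in retriever.query("retention policy", top_k=3):
    print(hit["id"], hit["score"])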

__init__(neo4j_uri, user, password, gds_graph_name, embed_model_fn)

Create the retriever. If neo4j_uri is falsy, the retriever is disabled.

Source code in src/safeagent/retriever.py
def __init__(self, neo4j_uri: str, user: str, password: str, gds_graph_name: str, embed_model_fn):
    """Create the retriever. If neo4j_uri is falsy, the retriever is disabled."""
    self.driver = None
    self.gov = GovernanceManager()
    self.embed = embed_model_fn
    self.gds_graph = gds_graph_name

    if not neo4j_uri:
        logging.info("GraphRetriever is disabled because no neo4j_uri was provided.")
        return

    try:
        from neo4j import GraphDatabase, exceptions
        self.driver = GraphDatabase.driver(neo4j_uri, auth=(user, password))
        # Test the connection to fail fast
        with self.driver.session() as session:
            session.run("RETURN 1")
        logging.info("Successfully connected to Neo4j.")
    except ImportError:
        logging.warning("The 'neo4j' library is not installed. GraphRetriever will be disabled.")
        self.driver = None
    except exceptions.ServiceUnavailable:
        logging.warning(f"Could not connect to Neo4j at '{neo4j_uri}'. GraphRetriever is disabled.")
        self.driver = None
    except Exception as e:
        logging.warning(f"An unexpected error occurred while connecting to Neo4j. GraphRetriever is disabled. Error: {e}")
        self.driver = None

index(embeddings, metadata)

Ingest each document as a node with a 'vector' property and 'metadata' (with lineage tagging).

Source code in src/safeagent/retriever.py
def index(self, embeddings: List[List[float]], metadata: List[Dict[str, Any]]):
    """
    Ingest each document as a node with a 'vector' property and 'metadata' (with lineage tagging).
    """
    if not self.driver:
        return 

    self.gov.audit(user_id="system", action="graph_index", resource="neo4j", metadata={"count": len(embeddings)})
    with self.driver.session() as session:
        for vec, meta in zip(embeddings, metadata):
            tagged_meta = self.gov.tag_lineage(meta.copy(), source="graph_index")
            session.run(
                "MERGE (d:Document {id: $id}) "
                "SET d.vector = $vector, d.metadata = $meta",
                id=meta["id"], vector=vec, meta=tagged_meta
            )
    log_entry = {
        "event": "graph_index",
        "count": len(embeddings),
        "timestamp": time.time()
    }
    logging.info(json.dumps(log_entry))

query(query_text, top_k=5)

Compute embedding for query_text, run GDS K-NN, and return nearest documents (with lineage tagging).

Source code in src/safeagent/retriever.py
def query(self, query_text: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Compute embedding for query_text, run GDS K-NN, and return nearest documents (with lineage tagging).
    """
    if not self.driver:
        return []

    # Encrypt and audit query
    encrypted_query = self.gov.encrypt(query_text)
    self.gov.audit(user_id="system", action="graph_query", resource="neo4j", metadata={"query_enc": encrypted_query[:50], "top_k": top_k})

    vec = self.embed(query_text)
    cypher = f"""
        CALL gds.knn.stream(
            '{self.gds_graph}',
            {{
                topK: $k,
                nodeWeightProperty: 'vector',
                queryVector: $vector
            }}
        ) YIELD nodeId, similarity
        RETURN gds.util.asNode(nodeId).id AS id, similarity
    """
    results = []
    try:
        with self.driver.session() as session:
            for record in session.run(cypher, vector=vec, k=top_k):
                node_id = record["id"]
                score = record["similarity"]
                meta_record = session.run(
                    "MATCH (d:Document {id: $id}) RETURN d.metadata AS meta", id=node_id
                ).single()
                if meta_record:
                    meta = meta_record["meta"]
                    tagged_meta = self.gov.tag_lineage(meta.copy(), source="graph_query")
                    results.append({"id": node_id, "score": score, "metadata": tagged_meta})
    except Exception as e:
        logging.error(f"Error querying Neo4j GDS: {e}")
        return []

    log_entry = {
        "event": "graph_query",
        "top_k": top_k,
        "timestamp": time.time()
    }
    logging.info(json.dumps(log_entry))
    return results

VectorRetriever

Bases: BaseRetriever

FAISS-backed vector retriever. Uses an embedding function to map text to vectors, with governance integration.

Source code in src/safeagent/retriever.py
class VectorRetriever(BaseRetriever):
    """FAISS-backed vector retriever. Uses an embedding function to map text to vectors, with governance integration."""
    def __init__(self, index_path: str, embed_model_fn):
        """
        Args:
            index_path (str): Filesystem path to store/load FAISS index.
            embed_model_fn (callable): Function that maps text (str) to a numpy ndarray vector.
        """
        self.embed = embed_model_fn
        self.gov = GovernanceManager()
        self.metadata_store: Dict[int, Dict[str, Any]] = {}
        self.next_id = 0
        self.index_path = index_path
        if _FAISS:
            if Path(index_path).exists():
                self._index = faiss.read_index(index_path)
            else:
                self._index = faiss.IndexFlatL2(768)
        else:
            self._index = []  # type: ignore

    def index(self, embeddings: List[np.ndarray], metadata: List[Dict[str, Any]]):
        """
        Add embeddings to the FAISS index and store metadata (with lineage tagging).

        Args:
            embeddings (List[np.ndarray]): List of vectors.
            metadata (List[Dict[str, Any]]): Corresponding metadata dicts (must include 'id').
        """
        if _FAISS:
            vectors = np.vstack(embeddings)
            self._index.add(vectors)
        else:
            for vec in embeddings:
                self._index.append(np.array(vec))
        for vec, meta in zip(embeddings, metadata):
            tagged_meta = self.gov.tag_lineage(meta.copy(), source="vector_index")
            self.metadata_store[self.next_id] = tagged_meta
            self.next_id += 1

        log_entry = {
            "event": "vector_index",
            "count": len(embeddings),
            "timestamp": time.time()
        }
        logging.info(json.dumps(log_entry))
        if _FAISS:
            faiss.write_index(self._index, self.index_path)

    def query(self, query_text: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Perform KNN search on the FAISS index using the embedded query, with encryption and audit.

        Args:
            query_text (str): The query string.
            top_k (int): Number of nearest neighbors to return.

        Returns:
            List[Dict[str, Any]]: Each dict contains 'id', 'score', and 'metadata'.
        """
        # Encrypt and audit query
        encrypted_query = self.gov.encrypt(query_text)
        self.gov.audit(user_id="system", action="vector_query", resource="faiss", metadata={"query_enc": encrypted_query[:50], "top_k": top_k})

        vec = self.embed(query_text)
        if _FAISS:
            distances, indices = self._index.search(np.array([vec]), top_k)
            idx_list = indices[0]
            dist_list = distances[0]
        else:
            if not self._index:
                idx_list, dist_list = [], []
            else:
                def dist(a, b):
                    return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) ** 0.5

                dists = [dist(v, vec) for v in self._index]
                sorted_idx = sorted(range(len(dists)), key=lambda i: dists[i])[:top_k]
                idx_list = sorted_idx
                dist_list = [dists[i] for i in sorted_idx]
        results = []
        for idx, dist in zip(idx_list, dist_list):
            meta = self.metadata_store.get(int(idx), {})
            results.append({"id": int(idx), "score": float(dist), "metadata": meta})

        log_entry = {
            "event": "vector_query",
            "top_k": top_k,
            "timestamp": time.time()
        }
        logging.info(json.dumps(log_entry))
        return results
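
Example usage (illustrative): a minimal sketch with a toy embedding function exercising the pure-Python fallback path; if FAISS is installed, vectors must instead match the 768-dimensional IndexFlatL2 created by default.

import numpy as np
from safeagent.retriever import VectorRetriever

def toy_embed(text: str) -> np.ndarray:
    # Illustrative only; a real deployment would call an embedding API here.
    return np.array([float(len(text)), float(text.count(" ")), 1.0])

retriever = VectorRetriever(index_path="docs.faiss", embed_model_fn=toy_embed)
retriever.index(
    embeddings=[toy_embed("retention policy"), toy_embed("audit log format")],
    metadata=[{"id": "doc-1"}, {"id": "doc-2"}],
)
for hit in retriever.query("audit retention", top_k=2):
    print(hit["id"], hit["score"], hit["metadata"].get("id"))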

__init__(index_path, embed_model_fn)

Parameters:

    index_path (str): Filesystem path to store/load the FAISS index. Required.
    embed_model_fn (callable): Function that maps text (str) to a numpy ndarray vector. Required.

Source code in src/safeagent/retriever.py
def __init__(self, index_path: str, embed_model_fn):
    """
    Args:
        index_path (str): Filesystem path to store/load FAISS index.
        embed_model_fn (callable): Function that maps text (str) to a numpy ndarray vector.
    """
    self.embed = embed_model_fn
    self.gov = GovernanceManager()
    self.metadata_store: Dict[int, Dict[str, Any]] = {}
    self.next_id = 0
    self.index_path = index_path
    if _FAISS:
        if Path(index_path).exists():
            self._index = faiss.read_index(index_path)
        else:
            self._index = faiss.IndexFlatL2(768)
    else:
        self._index = []  # type: ignore

index(embeddings, metadata)

Add embeddings to the FAISS index and store metadata (with lineage tagging).

Parameters:

    embeddings (List[np.ndarray]): List of vectors. Required.
    metadata (List[Dict[str, Any]]): Corresponding metadata dicts (must include 'id'). Required.

Source code in src/safeagent/retriever.py
def index(self, embeddings: List[np.ndarray], metadata: List[Dict[str, Any]]):
    """
    Add embeddings to the FAISS index and store metadata (with lineage tagging).

    Args:
        embeddings (List[np.ndarray]): List of vectors.
        metadata (List[Dict[str, Any]]): Corresponding metadata dicts (must include 'id').
    """
    if _FAISS:
        vectors = np.vstack(embeddings)
        self._index.add(vectors)
    else:
        for vec in embeddings:
            self._index.append(np.array(vec))
    for vec, meta in zip(embeddings, metadata):
        tagged_meta = self.gov.tag_lineage(meta.copy(), source="vector_index")
        self.metadata_store[self.next_id] = tagged_meta
        self.next_id += 1

    log_entry = {
        "event": "vector_index",
        "count": len(embeddings),
        "timestamp": time.time()
    }
    logging.info(json.dumps(log_entry))
    if _FAISS:
        faiss.write_index(self._index, self.index_path)

query(query_text, top_k=5)

Perform KNN search on the FAISS index using the embedded query, with encryption and audit.

Parameters:

    query_text (str): The query string. Required.
    top_k (int): Number of nearest neighbors to return. Default: 5.

Returns:

    List[Dict[str, Any]]: Each dict contains 'id', 'score', and 'metadata'.

Source code in src/safeagent/retriever.py
def query(self, query_text: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Perform KNN search on the FAISS index using the embedded query, with encryption and audit.

    Args:
        query_text (str): The query string.
        top_k (int): Number of nearest neighbors to return.

    Returns:
        List[Dict[str, Any]]: Each dict contains 'id', 'score', and 'metadata'.
    """
    # Encrypt and audit query
    encrypted_query = self.gov.encrypt(query_text)
    self.gov.audit(user_id="system", action="vector_query", resource="faiss", metadata={"query_enc": encrypted_query[:50], "top_k": top_k})

    vec = self.embed(query_text)
    if _FAISS:
        distances, indices = self._index.search(np.array([vec]), top_k)
        idx_list = indices[0]
        dist_list = distances[0]
    else:
        if not self._index:
            idx_list, dist_list = [], []
        else:
            def dist(a, b):
                return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) ** 0.5

            dists = [dist(v, vec) for v in self._index]
            sorted_idx = sorted(range(len(dists)), key=lambda i: dists[i])[:top_k]
            idx_list = sorted_idx
            dist_list = [dists[i] for i in sorted_idx]
    results = []
    for idx, dist in zip(idx_list, dist_list):
        meta = self.metadata_store.get(int(idx), {})
        results.append({"id": int(idx), "score": float(dist), "metadata": meta})

    log_entry = {
        "event": "vector_query",
        "top_k": top_k,
        "timestamp": time.time()
    }
    logging.info(json.dumps(log_entry))
    return results

register_retriever(name, cls)

Register a retriever class for dynamic loading.

Source code in src/safeagent/retriever.py
def register_retriever(name: str, cls):
    """Register a retriever class for dynamic loading."""
    RETRIEVER_REGISTRY[name] = cls
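
Example usage (illustrative): a short sketch that registers a hypothetical retriever class so it can be looked up by name in RETRIEVER_REGISTRY.

from safeagent.retriever import BaseRetriever, register_retriever

class CachingRetriever(BaseRetriever):
    """Hypothetical retriever used only to illustrate registration."""

    def index(self, embeddings, metadata):
        pass

    def query(self, query_text, top_k=5):
        return []

register_retriever("caching", CachingRetriever)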