minions-ai-agents/src/clients/ollama_client.py

"""
Ollama Client for Arthur Agent.
Provides local LLM inference for both triage (1B) and specialist (8B) models.
Optimized for CPU-only operation per PRD.
"""
import json
import logging
from dataclasses import dataclass
from typing import Optional, AsyncIterator

import httpx

from src.config import Config

logger = logging.getLogger("ArthurOllama")


@dataclass
class LLMResponse:
    """Response from LLM inference."""

    content: str
    model: str
    total_tokens: int
    eval_duration_ms: int
    prompt_eval_count: int


class OllamaClient:
    """
    Client for Ollama local LLM inference.

    Supports two models per PRD:
    - Triage (1B): fast extraction and classification
    - Specialist (8B): deep reasoning and response generation
    """

    def __init__(self):
        """Initialize the Ollama client from config."""
        config = Config.get_llm_config()
        self._base_url = config.ollama_base_url
        self._triage_model = config.triage_model
        self._specialist_model = config.specialist_model
        self._triage_context = config.triage_context
        self._specialist_context = config.specialist_context

        # HTTP client with a longer timeout for LLM calls (CPU inference can be slow)
        self._client = httpx.AsyncClient(
            base_url=self._base_url,
            timeout=httpx.Timeout(300.0, connect=10.0),  # 5 minutes for 8B on CPU
        )

    async def health_check(self) -> bool:
        """
        Check whether the Ollama server is running.

        Returns:
            True if healthy, False otherwise.
        """
        try:
            response = await self._client.get("/api/tags")
            if response.status_code == 200:
                models = response.json().get("models", [])
                logger.info(f"Ollama healthy. Available models: {len(models)}")
                return True
            return False
        except Exception as e:
            logger.error(f"Ollama health check failed: {e}")
            return False

    async def list_models(self) -> list[str]:
        """List the models available in Ollama."""
        try:
            response = await self._client.get("/api/tags")
            if response.status_code == 200:
                models = response.json().get("models", [])
                return [m["name"] for m in models]
            return []
        except Exception as e:
            logger.error(f"Failed to list models: {e}")
            return []

    async def generate_triage(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
    ) -> Optional[LLMResponse]:
        """
        Generate a response using the triage model (1B, fast).

        Used for:
        - Entity extraction (client, technology, problem)
        - Initial classification
        - Tool selection

        Args:
            prompt: User prompt.
            system_prompt: System instructions.

        Returns:
            LLMResponse, or None on failure.
        """
        return await self._generate(
            model=self._triage_model,
            prompt=prompt,
            system_prompt=system_prompt,
            num_ctx=self._triage_context,
        )

    async def generate_specialist(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
    ) -> Optional[LLMResponse]:
        """
        Generate a response using the specialist model (8B, reasoning).

        Used for:
        - Root cause analysis
        - Technical diagnosis
        - Response generation

        Args:
            prompt: User prompt with enriched context.
            system_prompt: System instructions.

        Returns:
            LLMResponse, or None on failure.
        """
        return await self._generate(
            model=self._specialist_model,
            prompt=prompt,
            system_prompt=system_prompt,
            num_ctx=self._specialist_context,
        )

    async def _generate(
        self,
        model: str,
        prompt: str,
        system_prompt: Optional[str] = None,
        num_ctx: int = 2048,
    ) -> Optional[LLMResponse]:
        """
        Core generation method.

        Args:
            model: Model name.
            prompt: User prompt.
            system_prompt: System instructions.
            num_ctx: Context window size.

        Returns:
            LLMResponse, or None on failure.
        """
        try:
            payload = {
                "model": model,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "num_ctx": num_ctx,
                    "temperature": 0.3,  # lower temperature for more deterministic output
                    "top_p": 0.9,
                },
            }
            if system_prompt:
                payload["system"] = system_prompt

            response = await self._client.post("/api/generate", json=payload)
            if response.status_code != 200:
                logger.error(f"Ollama error: {response.status_code} - {response.text}")
                return None

            data = response.json()
            return LLMResponse(
                content=data.get("response", ""),
                model=data.get("model", model),
                total_tokens=data.get("eval_count", 0) + data.get("prompt_eval_count", 0),
                # Ollama reports durations in nanoseconds; convert to milliseconds.
                eval_duration_ms=int(data.get("eval_duration", 0) / 1_000_000),
                prompt_eval_count=data.get("prompt_eval_count", 0),
            )
        except httpx.TimeoutException:
            logger.error(f"Ollama timeout for model {model}")
            return None
        except Exception as e:
            logger.error(f"Ollama generation failed: {e}")
            return None
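
    # For reference, a non-streaming /api/generate response looks roughly like:
    #
    #     {"model": "...", "response": "...", "done": true,
    #      "prompt_eval_count": 26, "eval_count": 112,
    #      "eval_duration": 5589157000}
    #
    # eval_duration is in nanoseconds, hence the division by 1_000_000 above.
    # The exact field set can vary by Ollama version.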

    async def generate_stream(
        self,
        model: str,
        prompt: str,
        system_prompt: Optional[str] = None,
    ) -> AsyncIterator[str]:
        """
        Stream response tokens.

        Args:
            model: Model name.
            prompt: User prompt.
            system_prompt: System instructions.

        Yields:
            Response tokens as they are generated.
        """
        try:
            payload = {
                "model": model,
                "prompt": prompt,
                "stream": True,
            }
            if system_prompt:
                payload["system"] = system_prompt

            async with self._client.stream(
                "POST", "/api/generate", json=payload
            ) as response:
                response.raise_for_status()
                # Ollama streams newline-delimited JSON, one object per token.
                async for line in response.aiter_lines():
                    if not line:
                        continue
                    data = json.loads(line)
                    if token := data.get("response"):
                        yield token
                    if data.get("done"):
                        break
        except Exception as e:
            logger.error(f"Streaming failed: {e}")

    async def get_embeddings(self, text: str) -> list[float]:
        """
        Generate an embedding for text using the triage model (1B) or a
        dedicated embedding model.

        Args:
            text: Input text.

        Returns:
            The embedding as a list of floats.
        """
        try:
            # The triage model is used for embeddings because it is smaller and
            # faster; production could use a dedicated embedding model (e.g. bge-m3).
            payload = {
                "model": self._triage_model,
                "prompt": text,
            }
            response = await self._client.post("/api/embeddings", json=payload)
            if response.status_code != 200:
                logger.error(f"Ollama embedding error: {response.status_code} - {response.text}")
                return []
            data = response.json()
            return data.get("embedding", [])
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            return []
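
    # Note: embeddings from a general-purpose 1B chat model are workable for
    # rough similarity, but a model trained for embeddings (e.g. bge-m3 or
    # nomic-embed-text, pulled into Ollama) will usually retrieve better.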

    async def close(self) -> None:
        """Close the HTTP client."""
        await self._client.aclose()


# Singleton instance
_ollama_client: Optional[OllamaClient] = None


def get_ollama_client() -> OllamaClient:
    """Get the global Ollama client instance."""
    global _ollama_client
    if _ollama_client is None:
        _ollama_client = OllamaClient()
    return _ollama_client
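

# Minimal smoke-test sketch (not part of the agent flow). It assumes only a
# local Ollama server with the configured models pulled; the ticket text and
# system prompt below are illustrative.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = get_ollama_client()
        if not await client.health_check():
            print("Ollama is not reachable")
            return
        print("Models:", await client.list_models())
        result = await client.generate_triage(
            prompt="Ticket: 'VPN drops every 30 minutes at site B.' Name the category.",
            system_prompt="Reply with a single category label.",
        )
        if result:
            print(f"[{result.model}] {result.content.strip()}")
            print(f"tokens={result.total_tokens} eval={result.eval_duration_ms}ms")
        await client.close()

    asyncio.run(_demo())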