"""
Ollama Client for Arthur Agent.

Provides local LLM inference for both triage (1B) and specialist (8B) models.
Optimized for CPU-only operation per PRD.
"""
|
|
|
|
import json
import logging
from dataclasses import dataclass
from typing import AsyncIterator, Optional

import httpx

from src.config import Config
|
logger = logging.getLogger("ArthurOllama")
|
|
|
|
|
|
@dataclass
|
|
class LLMResponse:
|
|
"""Response from LLM inference."""
|
|
content: str
|
|
model: str
|
|
total_tokens: int
|
|
eval_duration_ms: int
|
|
prompt_eval_count: int
|
|
|
|
|
|
class OllamaClient:
|
|
"""
|
|
Client for Ollama local LLM inference.
|
|
|
|
Supports two models per PRD:
|
|
- Triage (1B): Fast extraction and classification
|
|
- Specialist (8B): Deep reasoning and response generation
|
|
"""
|
|
|
|
def __init__(self):
|
|
"""Initialize Ollama client from config."""
|
|
config = Config.get_llm_config()
|
|
self._base_url = config.ollama_base_url
|
|
self._triage_model = config.triage_model
|
|
self._specialist_model = config.specialist_model
|
|
self._triage_context = config.triage_context
|
|
self._specialist_context = config.specialist_context
|
|
|
|
# HTTP client with longer timeout for LLM
|
|
self._client = httpx.AsyncClient(
|
|
base_url=self._base_url,
|
|
timeout=httpx.Timeout(120.0, connect=10.0)
|
|
)
|
|
|
|
async def health_check(self) -> bool:
|
|
"""
|
|
Check if Ollama server is running.
|
|
|
|
Returns:
|
|
True if healthy, False otherwise
|
|
"""
|
|
try:
|
|
response = await self._client.get("/api/tags")
|
|
if response.status_code == 200:
|
|
models = response.json().get("models", [])
|
|
logger.info(f"Ollama healthy. Available models: {len(models)}")
|
|
return True
|
|
return False
|
|
except Exception as e:
|
|
logger.error(f"Ollama health check failed: {e}")
|
|
return False
|
|
|
|
async def list_models(self) -> list[str]:
|
|
"""List available models in Ollama."""
|
|
try:
|
|
response = await self._client.get("/api/tags")
|
|
if response.status_code == 200:
|
|
models = response.json().get("models", [])
|
|
return [m["name"] for m in models]
|
|
return []
|
|
except Exception as e:
|
|
logger.error(f"Failed to list models: {e}")
|
|
return []
|
|
|
|
async def generate_triage(
|
|
self,
|
|
prompt: str,
|
|
system_prompt: Optional[str] = None
|
|
) -> Optional[LLMResponse]:
|
|
"""
|
|
Generate response using triage model (1B - fast).
|
|
|
|
Used for:
|
|
- Entity extraction (client, technology, problem)
|
|
- Initial classification
|
|
- Tool selection
|
|
|
|
Args:
|
|
prompt: User prompt
|
|
system_prompt: System instructions
|
|
|
|
Returns:
|
|
LLMResponse or None if failed
|
|
"""
|
|
return await self._generate(
|
|
model=self._triage_model,
|
|
prompt=prompt,
|
|
system_prompt=system_prompt,
|
|
num_ctx=self._triage_context
|
|
)
|
|
|
|
async def generate_specialist(
|
|
self,
|
|
prompt: str,
|
|
system_prompt: Optional[str] = None
|
|
) -> Optional[LLMResponse]:
|
|
"""
|
|
Generate response using specialist model (8B - reasoning).
|
|
|
|
Used for:
|
|
- Root cause analysis
|
|
- Technical diagnosis
|
|
- Response generation
|
|
|
|
Args:
|
|
prompt: User prompt with enriched context
|
|
system_prompt: System instructions
|
|
|
|
Returns:
|
|
LLMResponse or None if failed
|
|
"""
|
|
return await self._generate(
|
|
model=self._specialist_model,
|
|
prompt=prompt,
|
|
system_prompt=system_prompt,
|
|
num_ctx=self._specialist_context
|
|
)
|
|
|
|
async def _generate(
|
|
self,
|
|
model: str,
|
|
prompt: str,
|
|
system_prompt: Optional[str] = None,
|
|
num_ctx: int = 2048
|
|
) -> Optional[LLMResponse]:
|
|
"""
|
|
Core generation method.
|
|
|
|
Args:
|
|
model: Model name
|
|
prompt: User prompt
|
|
system_prompt: System instructions
|
|
num_ctx: Context window size
|
|
|
|
Returns:
|
|
LLMResponse or None if failed
|
|
"""
|
|
try:
|
|
payload = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"stream": False,
|
|
"options": {
|
|
"num_ctx": num_ctx,
|
|
"temperature": 0.3, # Lower for more deterministic
|
|
"top_p": 0.9,
|
|
}
|
|
}
|
|
|
|
if system_prompt:
|
|
payload["system"] = system_prompt
|
|
|
|
response = await self._client.post("/api/generate", json=payload)
|
|
|
|
if response.status_code != 200:
|
|
logger.error(f"Ollama error: {response.status_code} - {response.text}")
|
|
return None
|
|
|
|
data = response.json()
|
|
|
|
return LLMResponse(
|
|
content=data.get("response", ""),
|
|
model=data.get("model", model),
|
|
total_tokens=data.get("eval_count", 0) + data.get("prompt_eval_count", 0),
|
|
eval_duration_ms=int(data.get("eval_duration", 0) / 1_000_000),
|
|
prompt_eval_count=data.get("prompt_eval_count", 0)
|
|
)
|
|
|
|
except httpx.TimeoutException:
|
|
logger.error(f"Ollama timeout for model {model}")
|
|
return None
|
|
except Exception as e:
|
|
logger.error(f"Ollama generation failed: {e}")
|
|
return None
|
|
|
|
async def generate_stream(
|
|
self,
|
|
model: str,
|
|
prompt: str,
|
|
system_prompt: Optional[str] = None
|
|
) -> AsyncIterator[str]:
|
|
"""
|
|
Stream response tokens.
|
|
|
|
Args:
|
|
model: Model name
|
|
prompt: User prompt
|
|
system_prompt: System instructions
|
|
|
|
Yields:
|
|
Response tokens as they are generated
|
|
"""
|
|
try:
|
|
payload = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"stream": True,
|
|
}
|
|
|
|
if system_prompt:
|
|
payload["system"] = system_prompt
|
|
|
|
async with self._client.stream(
|
|
"POST", "/api/generate", json=payload
|
|
) as response:
|
|
async for line in response.aiter_lines():
|
|
if line:
|
|
import json
|
|
data = json.loads(line)
|
|
if token := data.get("response"):
|
|
yield token
|
|
if data.get("done"):
|
|
break
|
|
|
|
except Exception as e:
|
|
logger.error(f"Streaming failed: {e}")
|
|
|
|
async def get_embeddings(self, text: str) -> list[float]:
|
|
"""
|
|
Generate embeddings for text using the triage model (1B) or dedicated embedding model.
|
|
|
|
Args:
|
|
text: Input text
|
|
|
|
Returns:
|
|
List of floats representing the embedding
|
|
"""
|
|
try:
|
|
# Using triage model for embeddings as it's smaller/faster
|
|
# In production could use a specific embedding model (e.g. bge-m3)
|
|
payload = {
|
|
"model": self._triage_model,
|
|
"prompt": text,
|
|
}
|
|
|
|
response = await self._client.post("/api/embeddings", json=payload)
|
|
|
|
if response.status_code != 200:
|
|
logger.error(f"Ollama embedding error: {response.status_code} - {response.text}")
|
|
return []
|
|
|
|
data = response.json()
|
|
return data.get("embedding", [])
|
|
|
|
except Exception as e:
|
|
logger.error(f"Embedding generation failed: {e}")
|
|
return []
|
|
|
|
async def close(self) -> None:
|
|
"""Close HTTP client."""
|
|
await self._client.aclose()
|
|
|
|
|
|
# Singleton instance
|
|
_ollama_client: Optional[OllamaClient] = None
|
|
|
|
|
|
def get_ollama_client() -> OllamaClient:
|
|
"""Get global Ollama client instance."""
|
|
global _ollama_client
|
|
if _ollama_client is None:
|
|
_ollama_client = OllamaClient()
|
|
return _ollama_client
|