""" Ollama Client for Arthur Agent. Provides local LLM inference for both triage (1B) and specialist (8B) models. Optimized for CPU-only operation per PRD. """ import logging from typing import Optional, AsyncIterator from dataclasses import dataclass import httpx from src.config import Config logger = logging.getLogger("ArthurOllama") @dataclass class LLMResponse: """Response from LLM inference.""" content: str model: str total_tokens: int eval_duration_ms: int prompt_eval_count: int class OllamaClient: """ Client for Ollama local LLM inference. Supports two models per PRD: - Triage (1B): Fast extraction and classification - Specialist (8B): Deep reasoning and response generation """ def __init__(self): """Initialize Ollama client from config.""" config = Config.get_llm_config() self._base_url = config.ollama_base_url self._triage_model = config.triage_model self._specialist_model = config.specialist_model self._triage_context = config.triage_context self._specialist_context = config.specialist_context # HTTP client with longer timeout for LLM self._client = httpx.AsyncClient( base_url=self._base_url, timeout=httpx.Timeout(120.0, connect=10.0) ) async def health_check(self) -> bool: """ Check if Ollama server is running. Returns: True if healthy, False otherwise """ try: response = await self._client.get("/api/tags") if response.status_code == 200: models = response.json().get("models", []) logger.info(f"Ollama healthy. Available models: {len(models)}") return True return False except Exception as e: logger.error(f"Ollama health check failed: {e}") return False async def list_models(self) -> list[str]: """List available models in Ollama.""" try: response = await self._client.get("/api/tags") if response.status_code == 200: models = response.json().get("models", []) return [m["name"] for m in models] return [] except Exception as e: logger.error(f"Failed to list models: {e}") return [] async def generate_triage( self, prompt: str, system_prompt: Optional[str] = None ) -> Optional[LLMResponse]: """ Generate response using triage model (1B - fast). Used for: - Entity extraction (client, technology, problem) - Initial classification - Tool selection Args: prompt: User prompt system_prompt: System instructions Returns: LLMResponse or None if failed """ return await self._generate( model=self._triage_model, prompt=prompt, system_prompt=system_prompt, num_ctx=self._triage_context ) async def generate_specialist( self, prompt: str, system_prompt: Optional[str] = None ) -> Optional[LLMResponse]: """ Generate response using specialist model (8B - reasoning). Used for: - Root cause analysis - Technical diagnosis - Response generation Args: prompt: User prompt with enriched context system_prompt: System instructions Returns: LLMResponse or None if failed """ return await self._generate( model=self._specialist_model, prompt=prompt, system_prompt=system_prompt, num_ctx=self._specialist_context ) async def _generate( self, model: str, prompt: str, system_prompt: Optional[str] = None, num_ctx: int = 2048 ) -> Optional[LLMResponse]: """ Core generation method. 

        Args:
            model: Model name
            prompt: User prompt
            system_prompt: System instructions
            num_ctx: Context window size

        Returns:
            LLMResponse or None if failed
        """
        try:
            payload = {
                "model": model,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "num_ctx": num_ctx,
                    "temperature": 0.3,  # Lower temperature for more deterministic output
                    "top_p": 0.9,
                }
            }
            if system_prompt:
                payload["system"] = system_prompt

            response = await self._client.post("/api/generate", json=payload)

            if response.status_code != 200:
                logger.error(f"Ollama error: {response.status_code} - {response.text}")
                return None

            data = response.json()
            return LLMResponse(
                content=data.get("response", ""),
                model=data.get("model", model),
                total_tokens=data.get("eval_count", 0) + data.get("prompt_eval_count", 0),
                # eval_duration is reported in nanoseconds; convert to milliseconds
                eval_duration_ms=int(data.get("eval_duration", 0) / 1_000_000),
                prompt_eval_count=data.get("prompt_eval_count", 0)
            )
        except httpx.TimeoutException:
            logger.error(f"Ollama timeout for model {model}")
            return None
        except Exception as e:
            logger.error(f"Ollama generation failed: {e}")
            return None

    async def generate_stream(
        self,
        model: str,
        prompt: str,
        system_prompt: Optional[str] = None
    ) -> AsyncIterator[str]:
        """
        Stream response tokens.

        Args:
            model: Model name
            prompt: User prompt
            system_prompt: System instructions

        Yields:
            Response tokens as they are generated
        """
        try:
            payload = {
                "model": model,
                "prompt": prompt,
                "stream": True,
            }
            if system_prompt:
                payload["system"] = system_prompt

            async with self._client.stream(
                "POST", "/api/generate", json=payload
            ) as response:
                # Surface HTTP errors instead of silently iterating an error body
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if not line:
                        continue
                    data = json.loads(line)
                    if token := data.get("response"):
                        yield token
                    if data.get("done"):
                        break
        except Exception as e:
            logger.error(f"Streaming failed: {e}")

    async def get_embeddings(self, text: str) -> list[float]:
        """
        Generate embeddings for text using the triage model (1B)
        or a dedicated embedding model.

        Args:
            text: Input text

        Returns:
            List of floats representing the embedding
        """
        try:
            # Using the triage model for embeddings as it's smaller/faster.
            # In production a dedicated embedding model (e.g. bge-m3) could be used.
            payload = {
                "model": self._triage_model,
                "prompt": text,
            }
            response = await self._client.post("/api/embeddings", json=payload)

            if response.status_code != 200:
                logger.error(f"Ollama embedding error: {response.status_code} - {response.text}")
                return []

            data = response.json()
            return data.get("embedding", [])
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            return []

    async def close(self) -> None:
        """Close HTTP client."""
        await self._client.aclose()


# Singleton instance
_ollama_client: Optional[OllamaClient] = None


def get_ollama_client() -> OllamaClient:
    """Get global Ollama client instance."""
    global _ollama_client
    if _ollama_client is None:
        _ollama_client = OllamaClient()
    return _ollama_client
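

# Minimal usage sketch (assumptions: Config.get_llm_config() is already wired up and an
# Ollama server is reachable at the configured base URL; the prompt text below is purely
# illustrative). Not part of the module's public API; shown only to illustrate the intended
# call pattern of health_check() -> generate_triage() -> close().
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = get_ollama_client()
        if not await client.health_check():
            logger.warning("Ollama is not reachable; skipping demo")
            return
        result = await client.generate_triage(
            prompt="Extract the client and technology from: 'Acme reports nginx 502 errors'",
            system_prompt="Reply with a short JSON object."
        )
        if result:
            logger.info(f"{result.model} ({result.eval_duration_ms} ms): {result.content}")
        await client.close()

    asyncio.run(_demo())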