"""
Tests for Ollama Client.

Tests the local LLM inference client.
"""

import pytest
from unittest.mock import Mock, patch, MagicMock, AsyncMock

from src.clients.ollama_client import (
    OllamaClient,
    LLMResponse,
    get_ollama_client
)
|
|
|
|
|
|
class TestOllamaClient:
    """Unit tests covering the OllamaClient wrapper around the Ollama HTTP API."""

    @staticmethod
    def _http_mock(status=200, payload=None):
        """Build a fake httpx-style response with the given status and JSON body."""
        fake = Mock()
        fake.status_code = status
        if payload is not None:
            fake.json.return_value = payload
        return fake

    @pytest.fixture
    def client(self):
        """Provide a fresh OllamaClient instance per test."""
        return OllamaClient()

    def test_init_defaults(self, client):
        """The constructor should fall back to sane local defaults."""
        assert "localhost" in client._base_url
        assert client._triage_model is not None
        assert client._specialist_model is not None

    def test_llm_response_dataclass(self):
        """LLMResponse should retain every field passed at construction."""
        reply = LLMResponse(
            content="This is the model response",
            model="llama3.2:1b",
            total_tokens=150,
            eval_duration_ms=500,
            prompt_eval_count=50,
        )
        assert reply.content == "This is the model response"
        assert reply.total_tokens == 150
        assert reply.eval_duration_ms == 500

    @pytest.mark.asyncio
    async def test_health_check_success(self, client):
        """health_check returns True when the API answers 200 with a model list."""
        fake = self._http_mock(payload={"models": [{"name": "llama3.2:1b"}]})
        with patch.object(client._client, 'get', new_callable=AsyncMock) as fake_get:
            fake_get.return_value = fake
            assert await client.health_check() is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self, client):
        """health_check returns False when the underlying request raises."""
        with patch.object(client._client, 'get', new_callable=AsyncMock) as fake_get:
            fake_get.side_effect = Exception("Connection refused")
            assert await client.health_check() is False

    @pytest.mark.asyncio
    async def test_list_models_success(self, client):
        """list_models extracts every model name from the API payload."""
        payload = {
            "models": [
                {"name": "llama3.2:1b"},
                {"name": "llama3.1:8b"}
            ]
        }
        with patch.object(client._client, 'get', new_callable=AsyncMock) as fake_get:
            fake_get.return_value = self._http_mock(payload=payload)
            models = await client.list_models()
        assert len(models) == 2
        assert "llama3.2:1b" in models

    @pytest.mark.asyncio
    async def test_generate_triage_success(self, client):
        """generate_triage wraps a successful completion into an LLMResponse."""
        payload = {
            "response": "Extracted: client=OESTEPAN, problem=server down",
            "model": "llama3.2:1b",
            "eval_count": 20,
            "prompt_eval_count": 50,
            "eval_duration": 500_000_000,  # nanoseconds, as reported by Ollama
        }
        with patch.object(client._client, 'post', new_callable=AsyncMock) as fake_post:
            fake_post.return_value = self._http_mock(payload=payload)
            result = await client.generate_triage("Extract entities from: server is down")
        assert result is not None
        assert "OESTEPAN" in result.content
        assert result.model == "llama3.2:1b"

    @pytest.mark.asyncio
    async def test_generate_specialist_success(self, client):
        """generate_specialist passes the system prompt and wraps the reply."""
        payload = {
            "response": "Based on the analysis, the root cause is disk space exhaustion.",
            "model": "llama3.1:8b",
            "eval_count": 100,
            "prompt_eval_count": 200,
            "eval_duration": 2_000_000_000,
        }
        with patch.object(client._client, 'post', new_callable=AsyncMock) as fake_post:
            fake_post.return_value = self._http_mock(payload=payload)
            result = await client.generate_specialist(
                "Analyze this problem with context...",
                system_prompt="You are a technical support specialist."
            )
        assert result is not None
        assert "root cause" in result.content
        assert result.model == "llama3.1:8b"

    @pytest.mark.asyncio
    async def test_generate_error_handling(self, client):
        """A non-200 HTTP status should make generation return None."""
        failing = Mock()
        failing.status_code = 500
        failing.text = "Internal server error"
        with patch.object(client._client, 'post', new_callable=AsyncMock) as fake_post:
            fake_post.return_value = failing
            assert await client.generate_triage("Test prompt") is None
|
|
|
|
|
|
class TestOllamaSingleton:
    """Tests for the module-level singleton accessor."""

    def test_get_ollama_client_singleton(self):
        """Repeated calls must hand back one shared client instance."""
        # Clear any cached instance so this test controls singleton creation.
        import src.clients.ollama_client as module
        module._ollama_client = None

        first = get_ollama_client()
        second = get_ollama_client()

        assert first is second
|