171 lines
5.5 KiB
Python
171 lines
5.5 KiB
Python
"""
Tests for Flywheel Module - RAG Pipeline.

Tests document ingestion and processing.
"""
|
|
|
|
import pytest
|
|
import tempfile
|
|
import os
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch
|
|
|
|
from src.flywheel.rag_pipeline import (
|
|
RAGIngestionPipeline,
|
|
DocumentMetadata,
|
|
DocumentChunk,
|
|
IngestionResult,
|
|
get_rag_pipeline
|
|
)
|
|
|
|
|
|
class TestDocumentMetadata:
    """Checks on the DocumentMetadata dataclass."""

    def test_default_values(self):
        """Fields not passed to the constructor take their default values."""
        # Only these five fields are supplied; the rest are left to their defaults.
        supplied_fields = {
            "filename": "test.md",
            "filepath": "/path/test.md",
            "technology": "linux",
            "tenant_id": "tenant-001",
            "doc_type": "manual",
        }
        doc_meta = DocumentMetadata(**supplied_fields)

        assert doc_meta.language == "pt"
        assert doc_meta.version == "1.0"
        assert doc_meta.tags == []
|
|
|
|
|
|
class TestRAGPipeline:
    """Tests for RAGIngestionPipeline.

    Every test receives a pipeline from the ``pipeline`` fixture, which
    replaces ``src.flywheel.rag_pipeline.get_qdrant_client`` with a mock so
    no real Qdrant client is created by these tests.
    """

    @pytest.fixture
    def pipeline(self):
        """Yield a pipeline built while the Qdrant client factory is mocked.

        The fixture yields from inside the ``with patch(...)`` block instead
        of returning, so the patch stays in place until the test finishes.
        With a plain ``return`` the patch would be undone before the test
        body runs, and any later lookup of ``get_qdrant_client`` during the
        test would reach the real factory.
        """
        with patch('src.flywheel.rag_pipeline.get_qdrant_client') as mock:
            mock.return_value = Mock()
            # Make upserts through the mocked client report success.
            mock.return_value.upsert_document = Mock(return_value=True)
            yield RAGIngestionPipeline()

    def test_sanitize_removes_scripts(self, pipeline):
        """Test that script tags and their contents are removed."""
        content = "Normal text <script>alert('xss')</script> more text"
        result = pipeline._sanitize_content(content)

        assert "<script>" not in result
        assert "alert" not in result
        assert "Normal text" in result

    def test_sanitize_removes_base64(self, pipeline):
        """Test that a long base64-looking run is replaced by a marker."""
        # 150 consecutive "A" characters stand in for an embedded base64 blob.
        content = "Data: " + "A" * 150 + " end"
        result = pipeline._sanitize_content(content)

        assert "[BASE64_REMOVED]" in result

    def test_detect_technology_linux(self, pipeline):
        """Test Linux technology detection."""
        content = "Configure the Linux server using systemctl to manage services."
        tech = pipeline._detect_technology(content)

        assert tech == "linux"

    def test_detect_technology_docker(self, pipeline):
        """Test Docker technology detection."""
        content = "Build the container using Dockerfile and run with docker compose."
        tech = pipeline._detect_technology(content)

        assert tech == "docker"

    def test_detect_technology_network(self, pipeline):
        """Test network technology detection."""
        content = "Configure the firewall rules and routing tables."
        tech = pipeline._detect_technology(content)

        assert tech == "network"

    def test_extract_tags(self, pipeline):
        """Test tag extraction from content."""
        content = "This document covers Linux server administration with nginx as reverse proxy and postgresql database."
        tags = pipeline._extract_tags(content)

        assert "linux" in tags
        assert "nginx" in tags
        assert "postgresql" in tags

    def test_chunk_content(self, pipeline):
        """Test that chunking yields indexed chunks carrying the metadata."""
        meta = DocumentMetadata(
            filename="test.md",
            filepath="/test/test.md",
            technology="general",
            tenant_id="tenant-001",
            doc_type="manual"
        )

        content = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
        chunks = pipeline._chunk_content(content, meta)

        assert len(chunks) >= 1
        assert all(c.metadata == meta for c in chunks)
        assert chunks[0].chunk_index == 0

    def test_generate_chunk_id(self, pipeline):
        """Test that chunk IDs are deterministic and differ per chunk index."""
        id1 = pipeline._generate_chunk_id("/path/file.md", 0)
        id2 = pipeline._generate_chunk_id("/path/file.md", 1)
        id3 = pipeline._generate_chunk_id("/path/file.md", 0)

        assert id1 != id2
        assert id1 == id3  # Same inputs should give same ID

    def test_generate_embedding(self, pipeline):
        """Test that an embedding has 384 components, each within [-1, 1]."""
        emb = pipeline._generate_embedding("test content")

        assert len(emb) == 384
        assert all(-1 <= v <= 1 for v in emb)

    def test_ingest_file_not_found(self, pipeline):
        """Test that ingesting a missing file reports failure, not success."""
        result = pipeline.ingest_file(
            filepath="/nonexistent/file.md",
            tenant_id="tenant-001"
        )

        assert result.success is False
        assert "not found" in result.error.lower()

    def test_ingest_file_success(self, pipeline):
        """Test successful ingestion of a small markdown file on disk."""
        # delete=False: the file must still exist after the handle is closed
        # so the pipeline can open it by path; it is removed in ``finally``.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
            f.write("# Test Document\n\nThis is a test about Linux servers.")
            filepath = f.name

        try:
            result = pipeline.ingest_file(
                filepath=filepath,
                tenant_id="tenant-001"
            )

            assert result.success is True
            assert result.chunks_created >= 1
        finally:
            os.unlink(filepath)
|
|
|
|
|
|
class TestRAGPipelineSingleton:
    """Tests for the ``get_rag_pipeline`` accessor."""

    def test_singleton(self):
        """Test that two calls to ``get_rag_pipeline`` return the same object.

        The module-level ``_pipeline`` attribute is reset to ``None`` before
        the calls and put back to its previous value afterwards, so the
        pipeline created here (while ``get_qdrant_client`` is patched) does
        not stay cached for tests that run later.
        """
        import src.flywheel.rag_pipeline as module

        # Remember the current cached value (None if the attribute is absent).
        previous = getattr(module, "_pipeline", None)
        module._pipeline = None

        try:
            with patch('src.flywheel.rag_pipeline.get_qdrant_client'):
                p1 = get_rag_pipeline()
                p2 = get_rag_pipeline()

            assert p1 is p2
        finally:
            # Restore even when the assertion fails.
            module._pipeline = previous
|