Compare commits

..

No commits in common. "3ad316151971e953b87a4045f401488bfd73f7c3" and "733ff85c088c5edebb7d67ad4c1ccc2e76a2fe8a" have entirely different histories.

13 changed files with 139 additions and 191 deletions

View File

@ -95,23 +95,12 @@ Este documento serve como o roteiro técnico detalhado para a implementação do
- [x] **Mapeamento de Alterações por Fase:** - [x] **Mapeamento de Alterações por Fase:**
- Listar commits e arquivos modificados em cada uma das fases (1-5) - Listar commits e arquivos modificados em cada uma das fases (1-5)
- Gerar manifesto de auditoria: [AUDIT_MANIFEST.md](file:///C:/Users/joao.goncalves/Desktop/Projetos/minions-da-itguys/.gemini/AUDIT_MANIFEST.md) - Gerar manifesto de auditoria: [AUDIT_MANIFEST.md](file:///C:/Users/joao.goncalves/Desktop/Projetos/minions-da-itguys/.gemini/AUDIT_MANIFEST.md)
- [x] **Execução de Agente de Qualidade:** - [ ] **Execução de Agente de Qualidade:**
- [x] Análise "ponto a ponto" do código mapeado - Análise "ponto a ponto" do código mapeado
- [x] Focos: Otimizações, Falhas de Segurança, Bugs Lógicos, Code Quality - Focos: Otimizações, Falhas de Segurança, Bugs Lógicos, Code Quality
- [x] Resultado: [AUDIT_REPORT.md](file:///C:/Users/joao.goncalves/.gemini/antigravity/brain/b1ff0191-b9df-4504-9bc7-ebcbbf6c59e4/AUDIT_REPORT.md)
- [x] **Soluções de Auditoria (Fixes):**
- [x] **Crítico:** Criar `src/clients/financial_client.py` (Mock) para desbloquear homologação
- [x] **Crítico:** Implementar Embeddings Reais (BGE/Ollama) no RAG e Memória
- [x] **Alta:** Persistir `AuditLog` no PostgreSQL (Pipeline)
- [x] **Alta:** Refatorar `Config` para uso consistente de `SecretsManager`
- [x] **Média:** Otimizar N+1 queries no `zabbix_connector.py`
- [x] **Baixa:** Validação dinâmica de domínios em `validators.py`
- [ ] **Refinamento e Correção:** - [ ] **Refinamento e Correção:**
- [x] Verificar todas as alterações - Aplicar melhorias sugeridas pelo agente
- [x] **Segunda Passagem de Auditoria (Deep Dive)**: - Validar ausência de regressões
- [x] Análise de regressão e pontos cegos pós-correção
- Resultado: [AUDIT_DEEP_DIVE.md](file:///C:/Users/joao.goncalves/.gemini/antigravity/brain/0ae8ff87-2359-49bb-951c-6f6c593ee5db/AUDIT_DEEP_DIVE.md)
- [ ] Validar ausência de regressões
## Fase 7: Homologação e Go-Live 🔄 ## Fase 7: Homologação e Go-Live 🔄
- [ ] **Obter Credenciais:** - [ ] **Obter Credenciais:**

View File

@ -11,9 +11,8 @@ from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from src.agents.triage_agent import TriageResult, get_triage_agent from src.agents.triage_agent import TriageAgent, TriageResult, get_triage_agent
from src.agents.specialist_agent import SpecialistResponse, get_specialist_agent from src.agents.specialist_agent import SpecialistAgent, SpecialistResponse, get_specialist_agent
from src.database.connection import get_db_manager
from src.models import AuditLog, ResolutionStatus, TicketContext from src.models import AuditLog, ResolutionStatus, TicketContext
from src.security import sanitize_text from src.security import sanitize_text
@ -58,7 +57,6 @@ class TicketPipeline:
"""Initialize pipeline components.""" """Initialize pipeline components."""
self._triage = get_triage_agent() self._triage = get_triage_agent()
self._specialist = get_specialist_agent() self._specialist = get_specialist_agent()
self._db = get_db_manager()
async def process_email( async def process_email(
self, self,
@ -129,14 +127,10 @@ class TicketPipeline:
ticket_id=ticket_id, ticket_id=ticket_id,
sender_email=sender_email, sender_email=sender_email,
subject=subject, subject=subject,
body=body,
triage=triage_result, triage=triage_result,
specialist=specialist_result specialist=specialist_result
) )
# Save to database
await self._save_audit_log(result.audit_log)
result.success = specialist_result.success result.success = specialist_result.success
except Exception as e: except Exception as e:
@ -203,46 +197,11 @@ class TicketPipeline:
return "\n".join(lines) return "\n".join(lines)
async def _save_audit_log(self, log: AuditLog) -> None:
    """
    Persist one audit log record to the ``audit_logs`` table.

    Persistence is best-effort: any exception raised while building the
    parameters or executing the statement is logged and swallowed, so a
    database problem never propagates to the caller.

    Args:
        log: Audit record to store. ``context_collected`` is serialized with
            ``model_dump_json()`` and ``resolution_status`` is stored through
            its ``.value``; every other field is passed to the driver as-is.
    """
    try:
        # Positional placeholders ($1..$14) must stay in the same order as
        # the arguments passed to execute() below.
        query = """
            INSERT INTO audit_logs (
                ticket_id, tenant_id, sender_email, subject, original_message,
                context_collected, triage_model_output, specialist_model_reasoning,
                response_sent, tools_called, resolution_status, processing_time_ms,
                error_message, created_at
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
        """

        await self._db.execute(
            query,
            log.ticket_id,
            log.tenant_id,
            log.sender_email,
            log.subject,
            log.original_message,
            log.context_collected.model_dump_json(),
            log.triage_model_output,
            log.specialist_model_reasoning,
            log.response_sent,
            log.tools_called,
            log.resolution_status.value,
            log.processing_time_ms,
            log.error_message,
            log.created_at
        )
        logger.info(f"[{log.ticket_id}] Audit log saved successfully")

    except Exception as e:
        # Keep the traceback: the message alone rarely identifies which
        # column or connection step failed.
        logger.error(f"[{log.ticket_id}] Failed to save audit log: {e}", exc_info=True)
def _create_audit_log( def _create_audit_log(
self, self,
ticket_id: str, ticket_id: str,
sender_email: str, sender_email: str,
subject: str, subject: str,
body: str,
triage: TriageResult, triage: TriageResult,
specialist: SpecialistResponse specialist: SpecialistResponse
) -> AuditLog: ) -> AuditLog:
@ -280,7 +239,6 @@ class TicketPipeline:
tenant_id=triage.tenant.id if triage.tenant else "UNKNOWN", tenant_id=triage.tenant.id if triage.tenant else "UNKNOWN",
sender_email=sanitize_text(sender_email), sender_email=sanitize_text(sender_email),
subject=sanitize_text(subject), subject=sanitize_text(subject),
original_message=sanitize_text(body),
priority=triage.entities.priority.value, priority=triage.entities.priority.value,
category=triage.entities.category.value, category=triage.entities.category.value,
context_collected=context, context_collected=context,

View File

@ -74,8 +74,14 @@ class SelfCorrectionLayer:
def __init__(self): def __init__(self):
"""Initialize validator with allowed patterns.""" """Initialize validator with allowed patterns."""
# Allowed email domains are now managed dynamically via Financial System # Allowed email domains for responses
# self._allowed_domains = set() self._allowed_domains = {
"itguys.com.br",
"oestepan.com.br",
"oestepan.ind.br",
"enseg-rs.com.br",
"enseg.com.br",
}
# Blocked action patterns (prevent dangerous suggestions) # Blocked action patterns (prevent dangerous suggestions)
self._blocked_patterns = [ self._blocked_patterns = [
@ -230,7 +236,40 @@ class SelfCorrectionLayer:
return validation return validation
def validate_email_domain(self, email: str) -> ValidationResult:
    """
    Validate that an email address belongs to an allowed domain.

    Prevents sending responses to unknown domains. Surrounding whitespace is
    ignored and the domain (the text after the last "@") is compared in
    lower case against ``self._allowed_domains``.

    Args:
        email: Address to check.

    Returns:
        A ``ValidationResult`` created with ``is_valid=True``. An
        ``EMAIL_INVALID_FORMAT`` issue is added when the value is not a
        string or lacks a non-empty local part or domain; an
        ``EMAIL_DOMAIN_NOT_ALLOWED`` issue is added when the domain is not
        in the allowed set.
    """
    validation = ValidationResult(is_valid=True)

    # Addresses taken from mail headers often carry stray spaces or newlines;
    # without trimming, a legitimate domain would never match.
    address = email.strip() if isinstance(email, str) else ""

    # rpartition splits on the LAST "@", same as the former split("@")[-1].
    local_part, separator, domain = address.rpartition("@")
    domain = domain.lower()

    # "user@" and "@domain" are malformed, not "unknown domain" cases.
    if not separator or not local_part or not domain:
        validation.add_issue(ValidationIssue(
            code="EMAIL_INVALID_FORMAT",
            message="Formato de email inválido",
            severity=ValidationSeverity.ERROR,
            field="email"
        ))
        return validation

    if domain not in self._allowed_domains:
        validation.add_issue(ValidationIssue(
            code="EMAIL_DOMAIN_NOT_ALLOWED",
            message=f"Domínio '{domain}' não está na lista permitida",
            severity=ValidationSeverity.ERROR,
            field="email",
            suggestion="Adicionar domínio à lista de clientes ou verificar cadastro"
        ))

    return validation
def add_allowed_domain(self, domain: str) -> None:
    """
    Add a domain to the allowed list.

    The value is trimmed and lower-cased before it is stored, because
    ``validate_email_domain`` looks domains up in lower case. A blank value
    is ignored with a warning instead of being stored, since an empty entry
    would make an address with no domain ("user@") pass the allow-list check.

    Args:
        domain: Domain name to allow, e.g. ``"example.com"``.
    """
    domain = domain.strip().lower()
    if not domain:
        logger.warning("Ignoring empty allowed domain")
        return

    self._allowed_domains.add(domain)
    logger.info(f"Added allowed domain: {domain}")
def sanitize_response(self, response: str) -> str: def sanitize_response(self, response: str) -> str:
""" """

View File

@ -1,5 +1,5 @@
# Clients Module for Arthur Agent (External Integrations) # Clients Module for Arthur Agent (External Integrations)
from .financial_client import MockFinancialClient, FinancialClient, get_financial_client from .mock_financial import MockFinancialClient, FinancialClient, get_financial_client
from .mail_client import MailConfig from .mail_client import MailConfig
from .zabbix_connector import ZabbixConnector, get_zabbix_connector, HostStatus, Problem from .zabbix_connector import ZabbixConnector, get_zabbix_connector, HostStatus, Problem
from .qdrant_client import QdrantMultitenant, get_qdrant_client, SearchResult from .qdrant_client import QdrantMultitenant, get_qdrant_client, SearchResult

View File

@ -232,37 +232,6 @@ class OllamaClient:
except Exception as e: except Exception as e:
logger.error(f"Streaming failed: {e}") logger.error(f"Streaming failed: {e}")
async def get_embeddings(self, text: str) -> list[float]:
    """
    Generate an embedding for text through the ``/api/embeddings`` endpoint.

    The request is sent with the triage model configured on this client.

    Args:
        text: Input text

    Returns:
        List of floats representing the embedding. An empty list is returned
        when the server answers with a non-200 status, when the reply has no
        list under ``"embedding"``, or when the request raises.
    """
    try:
        # The triage model is used because it is smaller/faster; a dedicated
        # embedding model (e.g. bge-m3) could be used instead.
        payload = {
            "model": self._triage_model,
            "prompt": text,
        }

        response = await self._client.post("/api/embeddings", json=payload)

        if response.status_code != 200:
            logger.error(f"Ollama embedding error: {response.status_code} - {response.text}")
            return []

        data = response.json()
        embedding = data.get("embedding")
        # A null or malformed "embedding" must not leak to callers, which
        # rely on always receiving a list.
        return embedding if isinstance(embedding, list) else []

    except Exception as e:
        logger.error(f"Embedding generation failed: {e}")
        return []
async def close(self) -> None: async def close(self) -> None:
"""Close HTTP client.""" """Close HTTP client."""
await self._client.aclose() await self._client.aclose()

View File

@ -6,8 +6,7 @@ infrastructure diagnostics and root cause analysis.
""" """
import logging import logging
import time from typing import Optional, Any
from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
from zabbix_utils import ZabbixAPI from zabbix_utils import ZabbixAPI
@ -195,14 +194,13 @@ class ZabbixConnector:
try: try:
params = { params = {
"recent": True,
"sortfield": ["severity", "eventid"],
"sortorder": ["DESC", "DESC"],
"limit": limit,
"selectTags": "extend",
"output": ["eventid", "objectid", "severity", "name", "output": ["eventid", "objectid", "severity", "name",
"acknowledged", "clock", "r_clock"], "acknowledged", "clock", "r_eventid"]
"selectHosts": ["hostid", "host", "name"], # Fetch host info in same query
"selectTags": "extend", # Keep tags as they are in the Problem dataclass
"recent": True, # Keep recent as it was in the original params
"sortfield": ["severity", "eventid"], # Keep original sortfield
"sortorder": ["DESC", "DESC"], # Keep original sortorder
"limit": limit
} }
if host_id: if host_id:
@ -215,9 +213,8 @@ class ZabbixConnector:
result = [] result = []
for p in problems: for p in problems:
# Extract host info from payload # Get host info for this problem
hosts = p.get("hosts", []) host_info = self._get_host_for_trigger(p.get("objectid"))
host_info = hosts[0] if hosts else {}
result.append(Problem( result.append(Problem(
event_id=p["eventid"], event_id=p["eventid"],
@ -285,7 +282,8 @@ class ZabbixConnector:
if not neighbor_ids: if not neighbor_ids:
return [] return []
# Get problems for neighbor hosts (using selectHosts to avoid N+1) # Get problems for neighbor hosts
import time
time_from = int(time.time()) - (time_window_minutes * 60) time_from = int(time.time()) - (time_window_minutes * 60)
problems = self._api.problem.get( problems = self._api.problem.get(
@ -294,17 +292,13 @@ class ZabbixConnector:
recent=True, recent=True,
sortfield="eventid", sortfield="eventid",
sortorder="DESC", sortorder="DESC",
selectHosts=["hostid", "host"], # Fetch host info in same query
output=["eventid", "objectid", "severity", "name", output=["eventid", "objectid", "severity", "name",
"acknowledged", "clock"] "acknowledged", "clock"]
) )
result = [] result = []
for p in problems: for p in problems:
# Extract host info from payload (no extra API call) host_info = self._get_host_for_trigger(p.get("objectid"))
hosts = p.get("hosts", [])
host_info = hosts[0] if hosts else {}
result.append(Problem( result.append(Problem(
event_id=p["eventid"], event_id=p["eventid"],
host_id=host_info.get("hostid", ""), host_id=host_info.get("hostid", ""),

View File

@ -10,14 +10,9 @@ from dotenv import load_dotenv
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from src.security.secrets_manager import SecretsManager
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
# Setup Secrets
secrets = SecretsManager()
# Setup Logging # Setup Logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@ -96,24 +91,24 @@ class Config:
def get_qdrant_config() -> QdrantConfig: def get_qdrant_config() -> QdrantConfig:
"""Returns Qdrant connection configuration.""" """Returns Qdrant connection configuration."""
return QdrantConfig( return QdrantConfig(
host=secrets.get("QDRANT_HOST", "qdrant"), host=os.getenv("QDRANT_HOST", "qdrant"),
port=int(secrets.get("QDRANT_PORT", "6333")), port=int(os.getenv("QDRANT_PORT", "6333")),
collection_name=secrets.get("QDRANT_COLLECTION", "arthur_knowledge"), collection_name=os.getenv("QDRANT_COLLECTION", "arthur_knowledge"),
use_grpc=secrets.get("QDRANT_USE_GRPC", "false").lower() == "true", use_grpc=os.getenv("QDRANT_USE_GRPC", "false").lower() == "true",
on_disk=secrets.get("QDRANT_ON_DISK", "true").lower() == "true", on_disk=os.getenv("QDRANT_ON_DISK", "true").lower() == "true",
) )
@staticmethod @staticmethod
def get_postgres_config() -> PostgresConfig: def get_postgres_config() -> PostgresConfig:
"""Returns PostgreSQL configuration.""" """Returns PostgreSQL configuration."""
return PostgresConfig( return PostgresConfig(
host=secrets.get("POSTGRES_HOST", "postgres"), host=os.getenv("POSTGRES_HOST", "postgres"),
port=int(secrets.get("POSTGRES_PORT", "5432")), port=int(os.getenv("POSTGRES_PORT", "5432")),
database=secrets.get("POSTGRES_DB", "arthur_db"), database=os.getenv("POSTGRES_DB", "arthur_db"),
user=secrets.get("POSTGRES_USER", "arthur"), user=os.getenv("POSTGRES_USER", "arthur"),
password=secrets.get("POSTGRES_PASSWORD"), password=os.getenv("POSTGRES_PASSWORD"),
min_pool_size=int(secrets.get("POSTGRES_MIN_POOL", "2")), min_pool_size=int(os.getenv("POSTGRES_MIN_POOL", "2")),
max_pool_size=int(secrets.get("POSTGRES_MAX_POOL", "10")), max_pool_size=int(os.getenv("POSTGRES_MAX_POOL", "10")),
) )
@staticmethod @staticmethod
@ -143,12 +138,12 @@ class Config:
def get_mail_config() -> MailConfig: def get_mail_config() -> MailConfig:
"""Returns email configuration.""" """Returns email configuration."""
return MailConfig( return MailConfig(
imap_host=secrets.get("MAIL_IMAP_HOST", "mail.itguys.com.br"), imap_host=os.getenv("MAIL_IMAP_HOST", "mail.itguys.com.br"),
imap_port=int(secrets.get("MAIL_IMAP_PORT", "993")), imap_port=int(os.getenv("MAIL_IMAP_PORT", "993")),
smtp_host=secrets.get("MAIL_SMTP_HOST", "mail.itguys.com.br"), smtp_host=os.getenv("MAIL_SMTP_HOST", "mail.itguys.com.br"),
smtp_port=int(secrets.get("MAIL_SMTP_PORT", "587")), smtp_port=int(os.getenv("MAIL_SMTP_PORT", "587")),
email_address=secrets.get("MAIL_ADDRESS", "arthur.servicedesk@itguys.com.br"), email_address=os.getenv("MAIL_ADDRESS", "arthur.servicedesk@itguys.com.br"),
password=secrets.get("MAIL_PASSWORD"), password=os.getenv("MAIL_PASSWORD"),
) )
@staticmethod @staticmethod

View File

@ -36,8 +36,8 @@ class DatabaseManager:
return ( return (
f"postgresql://{self._secrets.get('POSTGRES_USER')}:" f"postgresql://{self._secrets.get('POSTGRES_USER')}:"
f"{self._secrets.get('POSTGRES_PASSWORD')}@" f"{self._secrets.get('POSTGRES_PASSWORD')}@"
f"{self._secrets.get('POSTGRES_HOST')}:" f"{os.getenv('POSTGRES_HOST', 'postgres')}:"
f"{self._secrets.get('POSTGRES_PORT')}/" f"{os.getenv('POSTGRES_PORT', '5432')}/"
f"{self._secrets.get('POSTGRES_DB')}" f"{self._secrets.get('POSTGRES_DB')}"
) )

View File

@ -239,7 +239,7 @@ class HomologationValidator:
client = get_financial_client() client = get_financial_client()
# Test tenant lookup # Test tenant lookup
tenant = await client.resolve_tenant_from_email("teste@oestepan.com.br") tenant = await client.get_tenant_by_email("teste@oestepan.com.br")
if tenant: if tenant:
check.status = ValidationStatus.PASSED check.status = ValidationStatus.PASSED

View File

@ -12,7 +12,7 @@ from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone from datetime import datetime, timezone
from enum import Enum from enum import Enum
from src.clients import get_qdrant_client, get_ollama_client from src.clients import get_qdrant_client
logger = logging.getLogger("ArthurMemory") logger = logging.getLogger("ArthurMemory")
@ -83,9 +83,8 @@ class EpisodicMemory:
"""Initialize episodic memory.""" """Initialize episodic memory."""
self._embedding_dim = embedding_dim self._embedding_dim = embedding_dim
self._qdrant = get_qdrant_client() self._qdrant = get_qdrant_client()
self._ollama = get_ollama_client()
async def store_lesson( def store_lesson(
self, self,
ticket_id: str, ticket_id: str,
tenant_id: str, tenant_id: str,
@ -144,7 +143,7 @@ class EpisodicMemory:
search_content = self._create_search_content(entry) search_content = self._create_search_content(entry)
# Generate embedding # Generate embedding
embedding = await self._generate_embedding(search_content) embedding = self._generate_embedding(search_content)
# Store in Qdrant # Store in Qdrant
success = self._qdrant.upsert_document( success = self._qdrant.upsert_document(
@ -176,7 +175,7 @@ class EpisodicMemory:
logger.error(f"Failed to store lesson: {e}") logger.error(f"Failed to store lesson: {e}")
return None return None
async def store_antipattern( def store_antipattern(
self, self,
ticket_id: str, ticket_id: str,
tenant_id: str, tenant_id: str,
@ -221,7 +220,7 @@ class EpisodicMemory:
) )
search_content = f"ANTIPADRÃO: {problem_summary}. NÃO FAZER: {failed_approach}" search_content = f"ANTIPADRÃO: {problem_summary}. NÃO FAZER: {failed_approach}"
embedding = await self._generate_embedding(search_content) embedding = self._generate_embedding(search_content)
success = self._qdrant.upsert_document( success = self._qdrant.upsert_document(
doc_id=memory_id, doc_id=memory_id,
@ -250,7 +249,7 @@ class EpisodicMemory:
logger.error(f"Failed to store antipattern: {e}") logger.error(f"Failed to store antipattern: {e}")
return None return None
async def search_similar( def search_similar(
self, self,
problem_description: str, problem_description: str,
tenant_id: str, tenant_id: str,
@ -273,7 +272,7 @@ class EpisodicMemory:
""" """
try: try:
# Generate embedding for search # Generate embedding for search
embedding = await self._generate_embedding(problem_description) embedding = self._generate_embedding(problem_description)
# Search in Qdrant # Search in Qdrant
results = self._qdrant.search( results = self._qdrant.search(
@ -335,9 +334,16 @@ class EpisodicMemory:
return "\n".join(parts) return "\n".join(parts)
async def _generate_embedding(self, text: str) -> List[float]: def _generate_embedding(self, text: str) -> List[float]:
"""Generate embedding for text using Ollama.""" """Generate embedding for text (placeholder)."""
return await self._ollama.get_embeddings(text) # Same placeholder as RAG pipeline
hash_bytes = hashlib.sha256(text.encode()).digest()
embedding = []
for i in range(self._embedding_dim):
byte_idx = i % len(hash_bytes)
value = (hash_bytes[byte_idx] / 255.0) * 2 - 1
embedding.append(value)
return embedding
def _reconstruct_entry(self, result) -> Optional[MemoryEntry]: def _reconstruct_entry(self, result) -> Optional[MemoryEntry]:
"""Reconstruct MemoryEntry from search result.""" """Reconstruct MemoryEntry from search result."""

View File

@ -5,7 +5,7 @@ Processes Markdown and PDF documents, extracts text,
generates embeddings and indexes in Qdrant. generates embeddings and indexes in Qdrant.
""" """
import os
import re import re
import hashlib import hashlib
import logging import logging
@ -14,7 +14,7 @@ from typing import Optional, List
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from src.clients import get_qdrant_client, get_ollama_client from src.clients import get_qdrant_client
logger = logging.getLogger("ArthurRAG") logger = logging.getLogger("ArthurRAG")
@ -92,9 +92,8 @@ class RAGIngestionPipeline:
self._chunk_overlap = chunk_overlap self._chunk_overlap = chunk_overlap
self._embedding_dim = embedding_dim self._embedding_dim = embedding_dim
self._qdrant = get_qdrant_client() self._qdrant = get_qdrant_client()
self._ollama = get_ollama_client()
async def ingest_directory( def ingest_directory(
self, self,
directory: str, directory: str,
tenant_id: str, tenant_id: str,
@ -137,7 +136,7 @@ class RAGIngestionPipeline:
logger.info(f"Found {len(files)} documents in {directory}") logger.info(f"Found {len(files)} documents in {directory}")
for filepath in files: for filepath in files:
result = await self.ingest_file( result = self.ingest_file(
filepath=str(filepath), filepath=str(filepath),
tenant_id=tenant_id, tenant_id=tenant_id,
doc_type=doc_type doc_type=doc_type
@ -146,7 +145,7 @@ class RAGIngestionPipeline:
return results return results
async def ingest_file( def ingest_file(
self, self,
filepath: str, filepath: str,
tenant_id: str, tenant_id: str,
@ -207,11 +206,8 @@ class RAGIngestionPipeline:
# Generate embeddings and index # Generate embeddings and index
indexed = 0 indexed = 0
for chunk in chunks: for chunk in chunks:
# Generate real embedding using Ollama # Placeholder embedding (in production, use sentence-transformers)
chunk.embedding = await self._generate_embedding(chunk.content) chunk.embedding = self._generate_embedding(chunk.content)
if not chunk.embedding:
logger.warning(f"Failed to generate embedding for chunk {chunk.id}")
continue
# Index in Qdrant # Index in Qdrant
success = self._qdrant.upsert_document( success = self._qdrant.upsert_document(
@ -367,14 +363,26 @@ class RAGIngestionPipeline:
content = f"{filepath}:{chunk_index}" content = f"{filepath}:{chunk_index}"
return hashlib.md5(content.encode()).hexdigest() return hashlib.md5(content.encode()).hexdigest()
async def _generate_embedding(self, text: str) -> List[float]: def _generate_embedding(self, text: str) -> List[float]:
""" """
Generate embedding for text using Ollama. Generate embedding for text.
"""
embedding = await self._ollama.get_embeddings(text) NOTE: This is a placeholder. In production, use:
- sentence-transformers with BGE-small
- Or Ollama embedding endpoint
"""
# Placeholder: return deterministic pseudo-embedding
# In production, replace with actual embedding model
import hashlib
hash_bytes = hashlib.sha256(text.encode()).digest()
# Create normalized vector from hash
embedding = []
for i in range(self._embedding_dim):
byte_idx = i % len(hash_bytes)
value = (hash_bytes[byte_idx] / 255.0) * 2 - 1 # Normalize to [-1, 1]
embedding.append(value)
# If model dimension differs, we might need padding/truncating (or just trust the model)
# For now we assume the model returns correct DIM or we handle it downstream
return embedding return embedding

View File

@ -8,7 +8,7 @@ import pytest
import tempfile import tempfile
import os import os
from pathlib import Path from pathlib import Path
from unittest.mock import Mock, patch, AsyncMock from unittest.mock import Mock, patch
from src.flywheel.rag_pipeline import ( from src.flywheel.rag_pipeline import (
RAGIngestionPipeline, RAGIngestionPipeline,
@ -42,16 +42,10 @@ class TestRAGPipeline:
@pytest.fixture @pytest.fixture
def pipeline(self): def pipeline(self):
"""Create pipeline with mocked Qdrant and Ollama.""" """Create pipeline with mocked Qdrant."""
with patch('src.flywheel.rag_pipeline.get_qdrant_client') as mock_qdrant, \ with patch('src.flywheel.rag_pipeline.get_qdrant_client') as mock:
patch('src.flywheel.rag_pipeline.get_ollama_client') as mock_ollama: mock.return_value = Mock()
mock_qdrant.return_value = Mock() mock.return_value.upsert_document = Mock(return_value=True)
mock_qdrant.return_value.upsert_document = Mock(return_value=True)
# Mock Ollama client for embeddings
mock_ollama.return_value = Mock()
mock_ollama.return_value.get_embeddings = AsyncMock(return_value=[0.1] * 384)
return RAGIngestionPipeline() return RAGIngestionPipeline()
def test_sanitize_removes_scripts(self, pipeline): def test_sanitize_removes_scripts(self, pipeline):
@ -126,19 +120,16 @@ class TestRAGPipeline:
assert id1 != id2 assert id1 != id2
assert id1 == id3 # Same inputs should give same ID assert id1 == id3 # Same inputs should give same ID
@pytest.mark.asyncio def test_generate_embedding(self, pipeline):
async def test_generate_embedding(self, pipeline): """Test embedding generation."""
"""Test embedding generation via Ollama.""" emb = pipeline._generate_embedding("test content")
emb = await pipeline._generate_embedding("test content")
# Embedding should be returned from mock assert len(emb) == 384
assert isinstance(emb, list) assert all(-1 <= v <= 1 for v in emb)
assert len(emb) > 0
@pytest.mark.asyncio def test_ingest_file_not_found(self, pipeline):
async def test_ingest_file_not_found(self, pipeline):
"""Test ingestion of non-existent file.""" """Test ingestion of non-existent file."""
result = await pipeline.ingest_file( result = pipeline.ingest_file(
filepath="/nonexistent/file.md", filepath="/nonexistent/file.md",
tenant_id="tenant-001" tenant_id="tenant-001"
) )
@ -146,15 +137,14 @@ class TestRAGPipeline:
assert result.success is False assert result.success is False
assert "not found" in result.error.lower() assert "not found" in result.error.lower()
@pytest.mark.asyncio def test_ingest_file_success(self, pipeline):
async def test_ingest_file_success(self, pipeline):
"""Test successful file ingestion.""" """Test successful file ingestion."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f: with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
f.write("# Test Document\n\nThis is a test about Linux servers.") f.write("# Test Document\n\nThis is a test about Linux servers.")
filepath = f.name filepath = f.name
try: try:
result = await pipeline.ingest_file( result = pipeline.ingest_file(
filepath=filepath, filepath=filepath,
tenant_id="tenant-001" tenant_id="tenant-001"
) )