160 lines
5.5 KiB
Python
160 lines
5.5 KiB
Python
|
|
import os
|
|
import re
|
|
from typing import List, Dict
|
|
import logging
|
|
from dotenv import load_dotenv
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.http import models
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
# --- Logging ---------------------------------------------------------------
# Timestamped INFO-level output for the whole script.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# --- Environment -----------------------------------------------------------
# Load variables from a .env file (python-dotenv) into os.environ. This must
# run before the os.getenv() lookups below so .env values are visible to them.
load_dotenv()

# --- Configuration ---------------------------------------------------------
_SCRIPT_DIR = os.path.dirname(__file__)

# Markdown catalog parsed by read_agent_catalog(), located relative to this file.
AGENT_CATALOG_PATH = os.path.join(_SCRIPT_DIR, '../docs/AGENT_CATALOG.md')

# Qdrant connection settings; override via environment variables or .env.
QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))

# Name of the collection holding the routing vectors.
COLLECTION_NAME = "routing_index"

# Sentence-transformers model used for embeddings. VECTOR_SIZE must equal this
# model's output dimension (384 for all-MiniLM-L6-v2); change both together.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
VECTOR_SIZE = 384
|
|
|
|
def read_agent_catalog(file_path: str) -> List[Dict]:
    """Read the agent catalog markdown file and extract one entry per agent.

    Each agent is introduced by a ``### <Agent Name>`` heading and described
    by bullet lines of the form::

        - **Papel:** <role>
        - **Especialidade:** <specialty>
        - **Crews:** <crew>

    Args:
        file_path: Path to the markdown catalog.

    Returns:
        A list of ``{"name", "description", "crew"}`` dicts in file order.
        ``description`` concatenates the role and specialty lines. Entries
        without a name, without any role/specialty line, or without a crew
        are skipped. Returns ``[]`` (after logging an error) if the file does
        not exist.
    """
    if not os.path.exists(file_path):
        logger.error(f"Agent catalog not found at {file_path}")
        return []

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    role_prefix = "- **Papel:**"
    specialty_prefix = "- **Especialidade:**"
    crew_prefix = "- **Crews:**"

    def value_after(line: str, prefix: str) -> str:
        # Slice off the whole bold label. Splitting on the first ":" would cut
        # inside the label ("**Papel:**") and leave the closing "**" in the value.
        return line[len(prefix):].strip()

    agents = []
    # Everything before the first "### " is the catalog header; drop it.
    agent_blocks = re.split(r'### ', content)[1:]

    for block in agent_blocks:
        # Strip every line so indented bullets are recognised as well.
        lines = [line.strip() for line in block.strip().split('\n')]
        name = lines[0]

        description = ""
        crew = ""

        for line in lines:
            if line.startswith(role_prefix):
                description += f"Papel: {value_after(line, role_prefix)}. "
            elif line.startswith(specialty_prefix):
                description += f"Especialidade: {value_after(line, specialty_prefix)}. "
            elif line.startswith(crew_prefix):
                crew = value_after(line, crew_prefix)

        if name and description and crew:
            agents.append({
                "name": name,
                "description": description,
                "crew": crew,
            })
            logger.info(f"Found agent: {name} (Crew: {crew})")

    return agents
|
|
|
|
def get_embeddings(texts: List[str]) -> List[List[float]]:
    """Encode each text with the configured SentenceTransformer model.

    The model named by EMBEDDING_MODEL_NAME is instantiated on every call;
    within this file it is called once per run (from populate_collection).

    Args:
        texts: Strings to embed.

    Returns:
        One embedding per input text, in input order, as plain Python lists
        of floats.
    """
    logger.info(f"Generating embeddings using model {EMBEDDING_MODEL_NAME}...")
    vectors = SentenceTransformer(EMBEDDING_MODEL_NAME).encode(texts)
    return vectors.tolist()
|
|
|
|
def init_qdrant_collection(client: QdrantClient, collection_name: str, vector_size: int):
    """Ensure a collection named *collection_name* exists in Qdrant.

    If a collection with that name is already present it is left untouched;
    its vector size and distance are NOT checked against *vector_size*.
    Otherwise a new collection is created with *vector_size*-dimensional
    vectors compared by cosine distance.
    """
    existing_names = {c.name for c in client.get_collections().collections}
    if collection_name in existing_names:
        logger.info(f"Collection '{collection_name}' already exists.")
        return

    logger.info(f"Creating collection '{collection_name}' with vector size {vector_size}...")
    vector_params = models.VectorParams(
        size=vector_size,
        distance=models.Distance.COSINE,
    )
    client.create_collection(
        collection_name=collection_name,
        vectors_config=vector_params,
    )
|
|
|
|
def populate_collection(client: QdrantClient, collection_name: str, agents: List[Dict]):
    """Embed each agent's description and upsert it into the collection.

    Point IDs are the 1-based positions of the agents in *agents*. Because
    IDs are positional and the collection is reused across runs, every point
    whose ID is not in the current set is deleted after the upsert. This
    stops agents removed from the catalog from lingering in the index. As a
    consequence, the collection should be dedicated to this routing index.

    Args:
        client: Connected Qdrant client.
        collection_name: Target collection; expected to exist already
            (main() creates it via init_qdrant_collection first).
        agents: Dicts with ``name``, ``description`` and ``crew`` keys, as
            returned by read_agent_catalog(). If empty, a warning is logged
            and nothing is written or deleted.
    """
    if not agents:
        logger.warning("No agents to index.")
        return

    embeddings = get_embeddings([agent["description"] for agent in agents])

    # The raw "Crews" string is stored as target_crew. An agent listed under
    # several crews gets a single point carrying the whole string.
    points = [
        models.PointStruct(
            id=point_id,
            vector=vector,
            payload={
                "agent_name": agent["name"],
                "target_crew": agent["crew"],
                "description": agent["description"],
            },
        )
        for point_id, (agent, vector) in enumerate(zip(agents, embeddings), start=1)
    ]

    logger.info(f"Upserting {len(points)} points into '{collection_name}'...")
    client.upsert(
        collection_name=collection_name,
        points=points,
    )

    # IDs are positional: if this catalog is shorter than on a previous run,
    # the higher IDs would otherwise keep pointing at removed agents.
    current_ids = [point.id for point in points]
    client.delete(
        collection_name=collection_name,
        points_selector=models.FilterSelector(
            filter=models.Filter(
                must_not=[models.HasIdCondition(has_id=current_ids)],
            ),
        ),
    )

    logger.info("Indexing complete.")
|
|
|
|
def main() -> int:
    """Build the Athena routing index in Qdrant from the agent catalog.

    Steps:
      1. Parse the catalog at AGENT_CATALOG_PATH.
      2. Connect to Qdrant at QDRANT_HOST:QDRANT_PORT.
      3. Ensure COLLECTION_NAME exists.
      4. Embed and upsert every agent.

    Returns:
        0 on success; 1 if no agents could be read from the catalog or Qdrant
        could not be reached. Errors raised while creating or populating the
        collection are not caught and propagate to the caller.
    """
    logger.info("Starting Athena DB Initialization...")

    # 1. Read Catalog
    agents = read_agent_catalog(AGENT_CATALOG_PATH)
    if not agents:
        logger.error("Failed to extract agents from catalog.")
        return 1

    # 2. Connect to Qdrant. get_collections() issues a request to the server,
    # so connection problems surface here rather than mid-indexing.
    try:
        client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
        client.get_collections()
    except Exception as e:  # top-level boundary: report and signal failure
        logger.error(f"Failed to connect to Qdrant: {e}")
        return 1
    logger.info(f"Connected to Qdrant at {QDRANT_HOST}:{QDRANT_PORT}")

    # 3. Init Collection
    init_qdrant_collection(client, COLLECTION_NAME, VECTOR_SIZE)

    # 4. Populate
    populate_collection(client, COLLECTION_NAME, agents)

    logger.info("Athena DB Initialization finished successfully.")
    return 0


if __name__ == "__main__":
    import sys

    # Propagate main()'s status so shells/CI can detect a failed initialization.
    sys.exit(main())
|