# minions-ai-agents/src/config.py
#
# 164 lines
# 5.6 KiB
# Python

import os
import logging
from dotenv import load_dotenv
import litellm
# Clear every LiteLLM callback list at import time to prevent missing
# dependency errors (APScheduler etc).
litellm.success_callback = []
litellm.failure_callback = []
litellm.callbacks = []
# Load environment variables via python-dotenv. This runs on import, i.e.
# before any Config method is called; Config reads its settings with os.getenv.
load_dotenv()
# Setup logging: root logging at INFO level, plus the named module logger
# that Config uses for its log messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("AntigravityConfig")
class Config:
    """
    Central Configuration for LLM and Memory Providers.

    Supports Dual-Model Strategy (Fast vs Smart). Every setting is read from
    environment variables at call time, so all methods are static and no
    instance is required.
    """

    # Temperature used for every agent-facing LLM config dictionary.
    _AGENT_TEMPERATURE = 0.7
    # Temperature used for the LLM section of the Mem0 config.
    _MEMORY_TEMPERATURE = 0.1
    # Providers whose model id is written as "<provider>/<model_name>".
    _PREFIXED_PROVIDERS = ("anthropic", "gemini", "ollama", "azure")
    # Fallback Ollama endpoint when OLLAMA_BASE_URL is not set.
    _DEFAULT_OLLAMA_URL = "http://localhost:11434"

    @staticmethod
    def _get_llm_dict(provider, model_name, base_url=None):
        """
        Helper to construct the LLM config dictionary.

        :param provider: Provider name; compared as-is against the known
            lowercase names, so callers should normalize it first.
        :param model_name: Bare model name, without any provider prefix.
        :param base_url: Endpoint URL; only used for the 'ollama' provider.
        :return: dict with 'model' and 'temperature', plus 'base_url' for
            ollama and 'api_key' (from GEMINI_API_KEY) for gemini.
        """
        if provider in Config._PREFIXED_PROVIDERS:
            model = f"{provider}/{model_name}"
        else:
            # 'openai' and any unrecognized provider use the bare model name.
            model = model_name

        llm = {"model": model}
        if provider == "ollama":
            llm["base_url"] = base_url
        llm["temperature"] = Config._AGENT_TEMPERATURE
        if provider == "gemini":
            llm["api_key"] = os.getenv("GEMINI_API_KEY")
        return llm

    @staticmethod
    def get_llm_config(mode="smart"):
        """
        Returns the LLM configuration.

        NOTE: the fallback model names ('gpt-3.5-turbo' / 'gpt-4o') are used
        for every provider, so LLM_MODEL_FAST / LLM_MODEL_SMART should be set
        explicitly whenever LLM_PROVIDER is not 'openai'.

        :param mode: 'fast' (High Speed/Low Cost) selects LLM_MODEL_FAST; any
            other value, including the default 'smart' (High Reasoning),
            selects LLM_MODEL_SMART.
        :return: dict built by ``_get_llm_dict`` for the configured provider.
        """
        # strip() so values such as " Gemini " still match a known provider.
        provider = os.getenv("LLM_PROVIDER", "openai").strip().lower()
        base_url = os.getenv("OLLAMA_BASE_URL", Config._DEFAULT_OLLAMA_URL)

        # Select Model Name based on Mode
        if mode == "fast":
            model_name = os.getenv("LLM_MODEL_FAST", "gpt-3.5-turbo")
        else:
            model_name = os.getenv("LLM_MODEL_SMART", "gpt-4o")

        logger.info(
            "Loading %s LLM: Provider=%s, Model=%s",
            mode.upper(), provider, model_name,
        )
        return Config._get_llm_dict(provider, model_name, base_url)

    @staticmethod
    def _mem0_llm_section(llm_provider):
        """Build the Mem0 'llm' section for *llm_provider*, or None if unmapped."""
        if llm_provider == "gemini":
            fast_model = os.getenv("LLM_MODEL_FAST", "gemini-2.0-flash-exp")
            return {
                "provider": "litellm",
                "config": {
                    "model": f"gemini/{fast_model}",
                    "api_key": os.getenv("GEMINI_API_KEY"),
                    "temperature": Config._MEMORY_TEMPERATURE,
                },
            }
        if llm_provider == "openai":
            return {
                "provider": "openai",
                "config": {
                    "model": os.getenv("LLM_MODEL_FAST", "gpt-3.5-turbo"),
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "temperature": Config._MEMORY_TEMPERATURE,
                },
            }
        if llm_provider == "ollama":
            return {
                "provider": "ollama",
                "config": {
                    "model": os.getenv("LLM_MODEL_FAST", "llama3"),
                    "ollama_base_url": os.getenv(
                        "OLLAMA_BASE_URL", Config._DEFAULT_OLLAMA_URL
                    ),
                    "temperature": Config._MEMORY_TEMPERATURE,
                },
            }
        return None

    @staticmethod
    def _mem0_embedder_section(embedding_provider):
        """Build the Mem0 'embedder' section, or None if the provider is unmapped."""
        if embedding_provider == "openai":
            return {
                "provider": "openai",
                "config": {
                    "model": "text-embedding-3-small",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            }
        if embedding_provider == "local":
            # Use sentence-transformers for fully local embeddings
            return {
                "provider": "huggingface",
                "config": {"model": "all-MiniLM-L6-v2"},
            }
        if embedding_provider == "gemini":
            return {
                "provider": "gemini",
                "config": {
                    "model": "models/text-embedding-004",
                    "api_key": os.getenv("GEMINI_API_KEY"),
                },
            }
        return None

    @staticmethod
    def _mem0_vector_store_section(memory_provider, project_id):
        """Build the Mem0 'vector_store' section; only 'qdrant' is mapped."""
        if memory_provider != "qdrant":
            return None
        return {
            "provider": "qdrant",
            "config": {
                "host": os.getenv("QDRANT_HOST", "localhost"),
                # int() raises ValueError when QDRANT_PORT is not numeric.
                "port": int(os.getenv("QDRANT_PORT", "6333")),
                "collection_name": project_id,
            },
        }

    @staticmethod
    def get_mem0_config():
        """
        Returns the Mem0 configuration with the correct LLM, Embedder, and Vector Store.
        Supports fully local mode with Ollama, or cloud with Gemini/OpenAI.

        Reads LLM_PROVIDER, MEMORY_EMBEDDING_PROVIDER, MEMORY_PROVIDER and
        MEMORY_PROJECT_ID. A section whose provider has no mapping here is
        left out of the result; for 'llm' and 'embedder' a warning is logged.

        :return: dict with 'version' and 'user_id', plus 'llm', 'embedder'
            and 'vector_store' when the respective provider is mapped.
        :raises ValueError: if MEMORY_PROVIDER is 'qdrant' and QDRANT_PORT is
            not an integer.
        """
        # strip() so values such as " Gemini " still match a known provider.
        memory_provider = os.getenv("MEMORY_PROVIDER", "mem0").strip().lower()
        embedding_provider = (
            os.getenv("MEMORY_EMBEDDING_PROVIDER", "openai").strip().lower()
        )
        llm_provider = os.getenv("LLM_PROVIDER", "openai").strip().lower()
        project_id = os.getenv("MEMORY_PROJECT_ID", "default_project")

        config = {"version": "v1.1"}

        # LLM Configuration (REQUIRED for Mem0 to process memories)
        llm_section = Config._mem0_llm_section(llm_provider)
        if llm_section is None:
            logger.warning(
                "No Mem0 LLM mapping for LLM_PROVIDER=%r; 'llm' section omitted",
                llm_provider,
            )
        else:
            config["llm"] = llm_section

        # Embedder Configuration
        embedder_section = Config._mem0_embedder_section(embedding_provider)
        if embedder_section is None:
            logger.warning(
                "No Mem0 embedder mapping for MEMORY_EMBEDDING_PROVIDER=%r; "
                "'embedder' section omitted",
                embedding_provider,
            )
        else:
            config["embedder"] = embedder_section

        # Vector Store Configuration
        vector_store_section = Config._mem0_vector_store_section(
            memory_provider, project_id
        )
        if vector_store_section is not None:
            config["vector_store"] = vector_store_section

        # Add user_id for memory scoping
        config["user_id"] = project_id
        return config

    @staticmethod
    def get_telegram_token():
        """Return TELEGRAM_BOT_TOKEN from the environment, or None if unset."""
        return os.getenv("TELEGRAM_BOT_TOKEN")

    @staticmethod
    def get_allowed_chats():
        """
        Return the chat ids listed in TELEGRAM_ALLOWED_CHAT_IDS.

        The variable is a comma-separated list of integers; blank entries and
        surrounding whitespace are ignored.

        :return: list of int chat ids. Empty when the variable is unset, or
            when any entry is not an integer (the whole list is rejected
            rather than partially accepted, and a warning is logged).
        """
        raw = os.getenv("TELEGRAM_ALLOWED_CHAT_IDS", "")
        try:
            return [int(part.strip()) for part in raw.split(",") if part.strip()]
        except ValueError:
            logger.warning(
                "TELEGRAM_ALLOWED_CHAT_IDS contains a non-integer entry (%r); "
                "no chats will be allowed",
                raw,
            )
            return []