import os
import logging

from dotenv import load_dotenv
import litellm

# Disable LiteLLM callbacks to avoid missing-dependency errors (APScheduler, etc.)
litellm.success_callback = []
litellm.failure_callback = []
litellm.callbacks = []

# Load environment variables from a local .env file, if present
load_dotenv()

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("AntigravityConfig")
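
# Environment variables read by this module (example values are illustrative;
# only the defaults hard-coded below are guaranteed):
#
#   LLM_PROVIDER=openai                # openai | anthropic | gemini | ollama | azure
#   LLM_MODEL_SMART=gpt-4o             # high-reasoning model
#   LLM_MODEL_FAST=gpt-3.5-turbo       # high-speed / low-cost model
#   OLLAMA_BASE_URL=http://localhost:11434
#   OPENAI_API_KEY=...
#   GEMINI_API_KEY=...
#   MEMORY_PROVIDER=mem0               # or: qdrant
#   MEMORY_EMBEDDING_PROVIDER=openai   # openai | local | gemini
#   MEMORY_PROJECT_ID=default_project
#   QDRANT_HOST=localhost
#   QDRANT_PORT=6333
#   TELEGRAM_BOT_TOKEN=...
#   TELEGRAM_ALLOWED_CHAT_IDS=123,456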


class Config:
    """
    Central configuration for LLM and memory providers.

    Supports a dual-model strategy: 'smart' (high reasoning) vs. 'fast' (low cost).
    """

    @staticmethod
    def _get_llm_dict(provider, model_name, base_url=None):
        """Helper to construct the LLM config dictionary for a given provider."""
        if provider == "openai":
            return {"model": model_name, "temperature": 0.7}
        elif provider == "anthropic":
            return {"model": f"anthropic/{model_name}", "temperature": 0.7}
        elif provider == "gemini":
            # LiteLLM format for Gemini
            return {
                "model": f"gemini/{model_name}",
                "temperature": 0.7,
                "api_key": os.getenv("GEMINI_API_KEY"),
            }
        elif provider == "ollama":
            return {
                "model": f"ollama/{model_name}",
                "base_url": base_url,
                "temperature": 0.7,
            }
        elif provider == "azure":
            return {"model": f"azure/{model_name}", "temperature": 0.7}
        else:
            # Unknown provider: pass the model name through unchanged
            return {"model": model_name, "temperature": 0.7}

    @staticmethod
    def get_llm_config(mode="smart"):
        """
        Return the LLM configuration for the requested mode.

        :param mode: 'smart' (high reasoning) or 'fast' (high speed / low cost)
        """
        provider = os.getenv("LLM_PROVIDER", "openai").lower()
        base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

        # Select the model name based on the requested mode
        if mode == "fast":
            model_name = os.getenv("LLM_MODEL_FAST", "gpt-3.5-turbo")
        else:
            model_name = os.getenv("LLM_MODEL_SMART", "gpt-4o")

        logger.info(f"Loading {mode.upper()} LLM: Provider={provider}, Model={model_name}")
        return Config._get_llm_dict(provider, model_name, base_url)
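
    # For example (actual values depend on the environment): with
    # LLM_PROVIDER=ollama and LLM_MODEL_FAST=llama3, Config.get_llm_config("fast")
    # returns
    #     {"model": "ollama/llama3", "base_url": "http://localhost:11434", "temperature": 0.7}
    # which callers can unpack into the LLM client of their choice.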

    @staticmethod
    def get_mem0_config():
        """
        Return the Mem0 configuration with the matching LLM, embedder, and vector store.

        Supports a fully local setup with Ollama, or cloud setups with Gemini/OpenAI.
        """
        memory_provider = os.getenv("MEMORY_PROVIDER", "mem0").lower()
        embedding_provider = os.getenv("MEMORY_EMBEDDING_PROVIDER", "openai").lower()
        llm_provider = os.getenv("LLM_PROVIDER", "openai").lower()
        project_id = os.getenv("MEMORY_PROJECT_ID", "default_project")

        config = {
            "version": "v1.1",
        }

        # LLM configuration (required for Mem0 to process memories)
        if llm_provider == "gemini":
            config["llm"] = {
                "provider": "litellm",
                "config": {
                    "model": f"gemini/{os.getenv('LLM_MODEL_FAST', 'gemini-2.0-flash-exp')}",
                    "api_key": os.getenv("GEMINI_API_KEY"),
                    "temperature": 0.1,
                },
            }
        elif llm_provider == "openai":
            config["llm"] = {
                "provider": "openai",
                "config": {
                    "model": os.getenv("LLM_MODEL_FAST", "gpt-3.5-turbo"),
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "temperature": 0.1,
                },
            }
        elif llm_provider == "ollama":
            config["llm"] = {
                "provider": "ollama",
                "config": {
                    "model": os.getenv("LLM_MODEL_FAST", "llama3"),
                    "ollama_base_url": os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
                    "temperature": 0.1,
                },
            }

        # Embedder configuration
        if embedding_provider == "openai":
            config["embedder"] = {
                "provider": "openai",
                "config": {
                    "model": "text-embedding-3-small",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            }
        elif embedding_provider == "local":
            # Use sentence-transformers for fully local embeddings
            config["embedder"] = {
                "provider": "huggingface",
                "config": {
                    "model": "all-MiniLM-L6-v2",
                },
            }
        elif embedding_provider == "gemini":
            # Gemini embeddings
            config["embedder"] = {
                "provider": "gemini",
                "config": {
                    "model": "models/text-embedding-004",
                    "api_key": os.getenv("GEMINI_API_KEY"),
                },
            }

        # Vector store configuration
        if memory_provider == "qdrant":
            config["vector_store"] = {
                "provider": "qdrant",
                "config": {
                    "host": os.getenv("QDRANT_HOST", "localhost"),
                    "port": int(os.getenv("QDRANT_PORT", "6333")),
                    "collection_name": project_id,
                },
            }

        # Scope memories to this project
        config["user_id"] = project_id

        return config
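
    # A minimal usage sketch (assumes the `mem0` package and its
    # Memory.from_config constructor; adjust for the installed mem0 version):
    #
    #     from mem0 import Memory
    #     memory = Memory.from_config(Config.get_mem0_config())
    #     memory.add("User prefers concise answers", user_id="default_project")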

    @staticmethod
    def get_telegram_token():
        return os.getenv("TELEGRAM_BOT_TOKEN")

    @staticmethod
    def get_allowed_chats():
        """Parse TELEGRAM_ALLOWED_CHAT_IDS (comma-separated) into a list of ints."""
        try:
            chats = os.getenv("TELEGRAM_ALLOWED_CHAT_IDS", "").split(",")
            return [int(c.strip()) for c in chats if c.strip()]
        except ValueError:
            # Malformed IDs: treat as "no chats allowed"
            return []
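

# A small, optional smoke test: running this module directly prints the resolved
# configurations. Illustrative only; it makes no network calls, and the values it
# shows simply reflect whatever the environment / .env file supplies.
if __name__ == "__main__":
    import json

    logger.info("Resolved SMART LLM config: %s", Config.get_llm_config("smart"))
    logger.info("Resolved FAST LLM config: %s", Config.get_llm_config("fast"))
    print(json.dumps(Config.get_mem0_config(), indent=2))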