AI Providers Configuration
Setting up and configuring AI/LLM providers for GoPie
GoPie supports multiple AI providers for natural language processing, SQL generation, and embeddings. This guide covers configuration and optimization for each provider.
Overview
GoPie uses AI providers for:
- Natural Language Understanding - Parsing user queries
- SQL Generation - Converting questions to SQL
- Schema Analysis - Understanding database structures
- Embeddings - Semantic search of schemas
- Code Execution - Python code generation for visualizations
OpenAI Configuration
Basic Setup
# .env configuration
OPENAI_API_KEY=sk-your-api-key
# Optional: organization ID
OPENAI_ORG_ID=org-your-org-id
# Optional: custom API endpoint
OPENAI_API_BASE=https://api.openai.com/v1
# Model Configuration
# chat-server/app/config/ai_config.py
from enum import Enum
from typing import Optional

try:
    # pydantic v2 moved BaseSettings into the separate pydantic-settings package.
    from pydantic_settings import BaseSettings
except ImportError:
    # pydantic v1 still ships BaseSettings in the core package.
    from pydantic import BaseSettings


class OpenAIModel(str, Enum):
    """Chat model identifiers passed to the OpenAI API."""

    GPT4 = "gpt-4"
    GPT4_TURBO = "gpt-4-turbo-preview"
    GPT35_TURBO = "gpt-3.5-turbo"
    GPT35_TURBO_16K = "gpt-3.5-turbo-16k"


class EmbeddingModel(str, Enum):
    """OpenAI embedding model identifiers."""

    ADA_002 = "text-embedding-ada-002"
    # NOTE: the two members below are text-embedding-3 models, not "ada";
    # the member names are kept unchanged for backward compatibility.
    ADA_003 = "text-embedding-3-small"
    ADA_003_LARGE = "text-embedding-3-large"


class OpenAIConfig(BaseSettings):
    """OpenAI settings read from ``OPENAI_``-prefixed environment variables.

    For example ``OPENAI_API_KEY`` populates ``api_key``, ``OPENAI_ORG_ID``
    populates ``org_id`` and ``OPENAI_API_BASE`` populates ``api_base``.
    """

    api_key: str
    org_id: Optional[str] = None
    api_base: str = "https://api.openai.com/v1"

    # Model selection
    default_model: OpenAIModel = OpenAIModel.GPT4
    embedding_model: EmbeddingModel = EmbeddingModel.ADA_002

    # Model parameters
    temperature: float = 0.7
    max_tokens: int = 2000
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0

    # Rate limiting / retries
    max_retries: int = 3
    retry_delay: float = 1.0  # presumably seconds — confirm against the caller
    request_timeout: int = 60  # seconds

    # Cost optimization
    enable_caching: bool = True
    cache_ttl: int = 3600  # 1 hour

    class Config:
        env_prefix = "OPENAI_"

# LangChain Integration
# chat-server/app/services/llm_service.py
import os
import re

import redis
from langchain_community.cache import RedisCache
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.globals import set_llm_cache
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from app.config.ai_config import OpenAIConfig

# Prompt used by LLMService.generate_sql; filled in with str.format.
_SQL_PROMPT = """
Given the following database schema:
{schema}
Generate a SQL query to answer this question:
{question}
Return only the SQL query without explanation.
"""

# A whole reply wrapped in a Markdown code fence, e.g. "```sql\nSELECT 1\n```".
_CODE_FENCE_RE = re.compile(r"```[A-Za-z]*[ \t]*\r?\n(.*?)\s*```", re.DOTALL)


def _strip_code_fence(text: str) -> str:
    """Return *text* stripped, unwrapping it if the whole reply is a code fence."""
    stripped = text.strip()
    match = _CODE_FENCE_RE.fullmatch(stripped)
    return match.group(1) if match else stripped


class LLMService:
    """LangChain OpenAI chat and embedding clients built from an OpenAIConfig."""

    def __init__(self, config: OpenAIConfig):
        self.config = config

        # Process-wide LangChain LLM cache backed by Redis, honouring cache_ttl.
        if config.enable_caching:
            redis_client = redis.Redis.from_url(
                os.getenv("REDIS_URL", "redis://localhost:6379")
            )
            set_llm_cache(RedisCache(redis_client, ttl=config.cache_ttl))

        # Chat model with streaming; every sampling option from the config is
        # forwarded. Enum members are passed as their plain string values.
        self.llm = ChatOpenAI(
            api_key=config.api_key,
            organization=config.org_id,
            base_url=config.api_base,
            model_name=config.default_model.value,
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            top_p=config.top_p,
            frequency_penalty=config.frequency_penalty,
            presence_penalty=config.presence_penalty,
            streaming=True,
            # Echoes streamed tokens to the process's stdout.
            callbacks=[StreamingStdOutCallbackHandler()],
            request_timeout=config.request_timeout,
            max_retries=config.max_retries,
        )

        self.embeddings = OpenAIEmbeddings(
            api_key=config.api_key,
            organization=config.org_id,
            base_url=config.api_base,
            model=config.embedding_model.value,
            request_timeout=config.request_timeout,
            max_retries=config.max_retries,
        )

    async def generate_sql(self, question: str, schema: str) -> str:
        """Generate a SQL query that answers *question* against *schema*.

        Returns the model's reply with surrounding whitespace removed. If the
        model wrapped the query in a Markdown code fence despite the
        instruction, the fence is removed as well.
        """
        prompt = _SQL_PROMPT.format(schema=schema, question=question)
        response = await self.llm.ainvoke(prompt)
        return _strip_code_fence(response.content)

# Anthropic Configuration
Setup
# .env configuration
ANTHROPIC_API_KEY=sk-ant-your-api-key
# Model: claude-3-opus-20240229 or claude-3-sonnet-20240229
ANTHROPIC_MODEL=claude-3-opus-20240229
ANTHROPIC_MAX_TOKENS=4000
# Anthropic Integration
# chat-server/app/services/anthropic_service.py
import os
from typing import AsyncGenerator

from anthropic import AsyncAnthropic

# Prompt used by AnthropicService.generate_sql_streaming; filled with str.format.
_SQL_PROMPT = """
You are a SQL expert. Given this database schema:
{schema}
Generate a SQL query to answer: {question}
Requirements:
1. Use proper SQL syntax
2. Handle edge cases
3. Optimize for performance
4. Return only the SQL query
"""


class AnthropicService:
    """Generates SQL with Claude models, streaming the reply as it arrives."""

    def __init__(self):
        # Credentials, model and token budget come from ANTHROPIC_* env vars.
        self.client = AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
        self.model = os.getenv("ANTHROPIC_MODEL", "claude-3-opus-20240229")
        self.max_tokens = int(os.getenv("ANTHROPIC_MAX_TOKENS", "4000"))

    async def generate_sql_streaming(
        self,
        question: str,
        schema: str,
    ) -> AsyncGenerator[str, None]:
        """Yield text fragments of a SQL query answering *question* over *schema*.

        The SDK's managed stream is used as an async context manager, so the
        HTTP response is closed even if the caller stops iterating early.
        """
        prompt = _SQL_PROMPT.format(schema=schema, question=question)
        async with self.client.messages.stream(
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=0.3,  # Lower temperature for SQL generation
            messages=[{"role": "user", "content": prompt}],
        ) as stream:
            # text_stream yields only the text deltas, skipping other events.
            async for text in stream.text_stream:
                yield text

# Azure OpenAI Configuration
Setup
# .env configuration
AZURE_OPENAI_API_KEY=your-azure-api-key
AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/
AZURE_OPENAI_DEPLOYMENT_NAME=your-deployment-name
# Separate deployment used for embeddings
AZURE_OPENAI_EMBEDDING_DEPLOYMENT=your-embedding-deployment-name
AZURE_OPENAI_API_VERSION=2024-02-15-preview
# Azure Integration
# chat-server/app/services/azure_openai_service.py
import os

from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings


class AzureOpenAIService:
    """LangChain chat and embedding clients for an Azure OpenAI resource.

    Settings come from the AZURE_OPENAI_* environment variables. Chat uses
    AZURE_OPENAI_DEPLOYMENT_NAME; embeddings use
    AZURE_OPENAI_EMBEDDING_DEPLOYMENT.
    """

    def __init__(self):
        # Connection settings shared by both clients.
        endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_key = os.getenv("AZURE_OPENAI_API_KEY")
        api_version = os.getenv("AZURE_OPENAI_API_VERSION")

        self.llm = AzureChatOpenAI(
            azure_endpoint=endpoint,
            api_key=api_key,
            azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
            api_version=api_version,
            temperature=0.7,
            max_tokens=2000,
            streaming=True,
        )

        self.embeddings = AzureOpenAIEmbeddings(
            azure_endpoint=endpoint,
            api_key=api_key,
            azure_deployment=os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT"),
            api_version=api_version,
        )

# Local Model Support
Ollama Configuration
# Install Ollama
curl -fsSL https://ollama.ai/install.sh | sh
# Pull models. Pull the exact tags the integration uses: the chat model set in
# OLLAMA_MODEL (codellama:13b) and the embedding model (nomic-embed-text).
# A bare "ollama pull codellama" fetches codellama:latest, not :13b.
ollama pull llama2
ollama pull codellama:13b
ollama pull mistral
ollama pull nomic-embed-text
# .env configuration
OLLAMA_BASE_URL=http://localhost:11434
OLLAMA_MODEL=codellama:13b
# Ollama Integration
# chat-server/app/services/ollama_service.py
import os

from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama


class OllamaService:
    """LLM and embedding clients for a local Ollama server."""

    def __init__(self):
        base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

        # Streaming is requested per call via .stream()/.astream(); the Ollama
        # LLM class has no `streaming` constructor option.
        self.llm = Ollama(
            base_url=base_url,
            model=os.getenv("OLLAMA_MODEL", "codellama:13b"),
            temperature=0.7,
            num_predict=2000,  # maximum number of tokens to generate
        )

        # Specialized embedding model; it must be pulled with `ollama pull`.
        # OLLAMA_EMBEDDING_MODEL optionally overrides the default.
        self.embeddings = OllamaEmbeddings(
            base_url=base_url,
            model=os.getenv("OLLAMA_EMBEDDING_MODEL", "nomic-embed-text"),
        )

# Multi-Provider Strategy
Provider Selection
# chat-server/app/services/ai_provider_factory.py
from enum import Enum


class AIProvider(str, Enum):
    """Identifiers of the supported AI backends."""

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    AZURE = "azure"
    OLLAMA = "ollama"


class AIProviderFactory:
    """Builds provider services and picks a default provider for each task."""

    # NOTE(review): BaseLLMService and OpenAIService are not defined in this
    # guide and must be imported from the project's service modules. The
    # return annotation is a string so this module imports without them.
    @staticmethod
    def create_llm_service(provider: AIProvider) -> "BaseLLMService":
        """Instantiate the service for *provider*.

        *provider* may be an AIProvider member or its string value
        (for example ``"openai"``).

        Raises:
            ValueError: if *provider* is not a known provider.
        """
        try:
            provider = AIProvider(provider)
        except ValueError:
            raise ValueError(f"Unknown provider: {provider}") from None

        # Each service class is referenced only inside its own branch.
        if provider is AIProvider.OPENAI:
            return OpenAIService()
        if provider is AIProvider.ANTHROPIC:
            return AnthropicService()
        if provider is AIProvider.AZURE:
            return AzureOpenAIService()
        if provider is AIProvider.OLLAMA:
            return OllamaService()
        # Reached only if a new AIProvider member is added without a branch.
        raise ValueError(f"Unknown provider: {provider}")

    @staticmethod
    def get_provider_for_task(task_type: str) -> AIProvider:
        """Return the preferred provider for *task_type* (OpenAI if unmapped)."""
        task_mapping = {
            "sql_generation": AIProvider.OPENAI,  # Best for SQL
            "code_generation": AIProvider.ANTHROPIC,  # Best for Python
            "embeddings": AIProvider.OPENAI,  # Best embeddings
            "local_dev": AIProvider.OLLAMA,  # For development
        }
        return task_mapping.get(task_type, AIProvider.OPENAI)

# Fallback Strategy
# chat-server/app/services/ai_fallback_service.py
import logging

logger = logging.getLogger(__name__)


class AllProvidersFailedError(Exception):
    """Raised when no provider in the fallback chain produced a response."""


class AIFallbackService:
    """Tries a chain of AI providers in priority order until one succeeds."""

    def __init__(self):
        # (provider name, service) pairs in priority order. A provider whose
        # constructor fails (e.g. missing credentials) is logged and skipped,
        # so one misconfigured backend does not disable the whole chain.
        candidates = [
            (AIProvider.OPENAI, OpenAIService),
            (AIProvider.ANTHROPIC, AnthropicService),
            (AIProvider.AZURE, AzureOpenAIService),
        ]
        self.providers = []
        for name, service_cls in candidates:
            try:
                self.providers.append((name, service_cls()))
            except Exception as exc:  # any init failure just disables that provider
                logger.warning("Provider %s unavailable: %s", name, exc)

    async def generate_with_fallback(
        self,
        prompt: str,
        max_retries: int = 3,
    ) -> str:
        """Return the first successful ``provider.generate(prompt)`` result.

        Each provider is attempted up to *max_retries* times (at least once)
        before the next provider in the chain is tried.

        Raises:
            AllProvidersFailedError: if every attempt failed or no provider is
                configured. It subclasses Exception, so existing
                ``except Exception`` handlers still catch it. The last
                underlying error is chained as ``__cause__``.
        """
        if not self.providers:
            raise AllProvidersFailedError("No AI providers are configured")

        attempts = max(1, max_retries)
        errors = []
        last_error = None
        for provider_name, provider in self.providers:
            for attempt in range(1, attempts + 1):
                try:
                    response = await provider.generate(prompt)
                except Exception as exc:  # fall through to the next attempt/provider
                    last_error = exc
                    errors.append(f"{provider_name}: {str(exc)}")
                    logger.warning(
                        "Provider %s failed (attempt %d/%d): %s",
                        provider_name, attempt, attempts, exc,
                    )
                else:
                    logger.info("Successfully used %s", provider_name)
                    return response

        raise AllProvidersFailedError(
            f"All providers failed: {'; '.join(errors)}"
        ) from last_error

# Embeddings Configuration
Vector Dimension Management
# chat-server/app/services/embedding_service.py
import os
from typing import List


class EmbeddingService:
    """Batches texts through an embedding provider.

    The provider key is read from the EMBEDDING_PROVIDER environment variable
    (default ``"openai"``). ``dimensions`` holds the vector size for that key.
    """

    def __init__(self):
        self.provider = os.getenv("EMBEDDING_PROVIDER", "openai")
        self.dimensions = self._get_dimensions()

    def _get_dimensions(self) -> int:
        """Return the embedding size for ``self.provider``; 1536 if unknown."""
        # NOTE: the default key "openai" is not listed here, so it resolves
        # through the 1536 fallback.
        dimension_map = {
            "openai-ada-002": 1536,
            "openai-3-small": 1536,
            "openai-3-large": 3072,
            "cohere-embed-v3": 1024,
            "voyage-2": 1024,
            "bge-large": 1024,
            "e5-large-v2": 1024,
        }
        return dimension_map.get(self.provider, 1536)

    async def generate_embeddings(
        self,
        texts: List[str],
        batch_size: int = 100,
    ) -> List[List[float]]:
        """Embed *texts* in batches of at most *batch_size*, preserving order.

        Raises:
            ValueError: if *batch_size* is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        all_embeddings: List[List[float]] = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            all_embeddings.extend(await self._embed_batch(batch))
        return all_embeddings

    async def _embed_batch(self, batch: List[str]) -> List[List[float]]:
        """Embed one batch, returning one vector per input text.

        Provider-specific; not shown in this guide, so a concrete
        implementation must override it.
        """
        raise NotImplementedError(
            f"No embedding implementation for provider {self.provider!r}"
        )

# Cost Optimization
Token Counting
# chat-server/app/utils/token_counter.py
import tiktoken
from typing import Any, Dict, Optional


class TokenCounter:
    """Counts tokens with tiktoken and estimates API cost per model."""

    def __init__(self):
        # Keys are model-name prefixes; lookups use the longest matching key,
        # so "gpt-3.5-turbo-16k" uses the "gpt-3.5-turbo" entry.
        self.encoders = {
            "gpt-4": tiktoken.encoding_for_model("gpt-4"),
            "gpt-3.5-turbo": tiktoken.encoding_for_model("gpt-3.5-turbo"),
            "claude": tiktoken.get_encoding("cl100k_base"),  # Approximation
        }

        # Cost per 1K tokens (in USD). Prices change over time; verify them
        # against the providers' current pricing pages.
        self.costs = {
            "gpt-4": {"input": 0.03, "output": 0.06},
            "gpt-4-turbo": {"input": 0.01, "output": 0.03},
            "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
            "claude-3-opus": {"input": 0.015, "output": 0.075},
            "claude-3-sonnet": {"input": 0.003, "output": 0.015},
        }

    @staticmethod
    def _lookup(model: str, table: Dict[str, Any]) -> Optional[Any]:
        """Return the value whose key is the longest prefix of *model*, or None.

        For example "gpt-4-turbo-preview" selects "gpt-4-turbo" and
        "claude-3-opus-20240229" selects "claude-3-opus".
        """
        matches = [key for key in table if model.startswith(key)]
        if not matches:
            return None
        return table[max(matches, key=len)]

    def count_tokens(self, text: str, model: str) -> int:
        """Count tokens in *text*; unknown models use the gpt-4 encoder."""
        encoder = self._lookup(model, self.encoders)
        if encoder is None:
            encoder = self.encoders["gpt-4"]
        return len(encoder.encode(text))

    def estimate_cost(
        self,
        input_text: str,
        output_text: str,
        model: str,
    ) -> Dict[str, float]:
        """Estimate the USD cost of one request/response pair for *model*.

        Dated model names (e.g. "claude-3-opus-20240229") resolve to their
        base price entry. Models with no matching entry fall back to gpt-4
        pricing.
        """
        input_tokens = self.count_tokens(input_text, model)
        output_tokens = self.count_tokens(output_text, model)

        costs = self._lookup(model, self.costs)
        if costs is None:
            costs = self.costs["gpt-4"]
        input_cost = (input_tokens / 1000) * costs["input"]
        output_cost = (output_tokens / 1000) * costs["output"]

        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "input_cost": input_cost,
            "output_cost": output_cost,
            "total_cost": input_cost + output_cost,
        }

# Caching Strategy
# chat-server/app/services/ai_cache_service.py
import hashlib
import json
from typing import Optional


class AICacheService:
    """Redis cache of AI responses keyed by prompt, model and parameters.

    Expects an asyncio Redis client: ``get``, ``set`` and ``expire`` are
    awaited. Works whether or not the client decodes responses to ``str``.
    """

    def __init__(self, redis_client, ttl: int = 3600):
        self.redis = redis_client
        self.ttl = ttl  # seconds; default 1 hour

    def _generate_cache_key(self, prompt: str, model: str, params: dict) -> str:
        """Return ``ai_cache:<sha256>`` of the request serialized as sorted-key JSON."""
        payload = json.dumps(
            {"prompt": prompt, "model": model, "params": params},
            sort_keys=True,
        )
        return f"ai_cache:{hashlib.sha256(payload.encode()).hexdigest()}"

    async def get_cached_response(
        self,
        prompt: str,
        model: str,
        params: dict,
    ) -> Optional[str]:
        """Return the cached response, or None on a miss.

        A cached empty string is a hit. Each hit resets the entry's TTL.
        """
        key = self._generate_cache_key(prompt, model, params)
        cached = await self.redis.get(key)
        if cached is None:
            return None
        # Sliding expiration: frequently used entries stay cached.
        await self.redis.expire(key, self.ttl)
        # Clients created with decode_responses=True already return str.
        return cached.decode() if isinstance(cached, bytes) else cached

    async def cache_response(
        self,
        prompt: str,
        model: str,
        params: dict,
        response: str,
    ):
        """Store *response* under the request's key with the configured TTL."""
        key = self._generate_cache_key(prompt, model, params)
        await self.redis.set(key, response, ex=self.ttl)

# Rate Limiting
Provider-Specific Limits
# chat-server/app/middleware/rate_limiter.py
import asyncio
from datetime import datetime, timezone
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # aioredis is deprecated (merged into redis-py as redis.asyncio) and fails
    # to import on Python 3.11+. The client type is only used in annotations.
    from redis.asyncio import Redis


class AIRateLimiter:
    """Per-minute (fixed-window) request and token limits per provider key."""

    def __init__(self, redis: "Redis"):
        self.redis = redis

        # Rate limits per provider (requests per minute); unknown keys get 1000.
        self.limits = {
            "openai-gpt4": 200,
            "openai-gpt35": 3500,
            "anthropic-claude3": 1000,
            "azure-openai": 600,
        }

        # Token limits per minute; providers not listed have no token limit.
        self.token_limits = {
            "openai-gpt4": 40000,
            "openai-gpt35": 90000,
            "anthropic-claude3": 100000,
        }

    async def check_rate_limit(
        self,
        provider: str,
        tokens: int = 0,
    ) -> bool:
        """Record one request (plus *tokens*) and return whether it is allowed.

        Counters live in per-minute keys that expire 60 seconds after they are
        created. Rejected calls still increment the counters.
        """
        # UTC keeps the minute buckets identical across instances regardless
        # of their local timezone.
        bucket = datetime.now(timezone.utc).strftime("%Y%m%d%H%M")
        minute_key = f"rate_limit:{provider}:{bucket}"
        token_key = f"token_limit:{provider}:{bucket}"

        # Check request count
        current_count = await self.redis.incr(minute_key)
        if current_count == 1:
            await self.redis.expire(minute_key, 60)
        if current_count > self.limits.get(provider, 1000):
            return False

        # Check token count
        if tokens > 0 and provider in self.token_limits:
            current_tokens = await self.redis.incrby(token_key, tokens)
            if current_tokens == tokens:  # first increment in this window
                await self.redis.expire(token_key, 60)
            if current_tokens > self.token_limits[provider]:
                return False

        return True

    async def wait_if_needed(self, provider: str, tokens: int = 0):
        """Poll once per second until a request carrying *tokens* is admitted."""
        while not await self.check_rate_limit(provider, tokens):
            await asyncio.sleep(1)

# Monitoring and Observability
Metrics Collection
# chat-server/app/monitoring/ai_metrics.py
from prometheus_client import Counter, Histogram, Gauge
import time

# Prometheus collectors for AI API usage.
ai_requests_total = Counter(
    "ai_requests_total",
    "Total AI API requests",
    ["provider", "model", "status"],
)

ai_request_duration = Histogram(
    "ai_request_duration_seconds",
    "AI API request duration",
    ["provider", "model"],
)

ai_tokens_used = Counter(
    "ai_tokens_used_total",
    "Total tokens used",
    ["provider", "model", "type"],  # type: input/output
)

ai_cost_total = Counter(
    "ai_cost_usd_total",
    "Total AI API cost in USD",
    ["provider", "model"],
)


class AIMetricsCollector:
    """Records one AI request in the module's Prometheus collectors."""

    @staticmethod
    async def track_request(
        provider: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        duration: float,
        cost: float,
        success: bool
    ):
        """Update request count, latency, token and cost metrics.

        The request counter is labelled ``status="success"`` or
        ``status="failure"`` according to *success*; tokens are split by
        ``type="input"`` / ``type="output"``.
        """
        common = {"provider": provider, "model": model}
        outcome = "success" if success else "failure"

        ai_requests_total.labels(status=outcome, **common).inc()
        ai_request_duration.labels(**common).observe(duration)

        for token_type, amount in (("input", input_tokens), ("output", output_tokens)):
            ai_tokens_used.labels(type=token_type, **common).inc(amount)

        ai_cost_total.labels(**common).inc(cost)

# Logging Configuration
# chat-server/app/utils/ai_logger.py
import structlog
from typing import Dict, Any

# Longest prompt/response excerpt logged before truncation.
_PREVIEW_CHARS = 100


def _preview(text: str) -> str:
    """Return *text* as-is, or its first 100 characters plus "..." if longer."""
    if len(text) > _PREVIEW_CHARS:
        return text[:_PREVIEW_CHARS] + "..."
    return text


class AILogger:
    """Emits one structured ``ai_request`` event per AI call via structlog."""

    def __init__(self):
        self.logger = structlog.get_logger()

    def log_request(
        self,
        provider: str,
        model: str,
        prompt: str,
        response: str,
        metadata: Dict[str, Any]
    ):
        """Log an AI request with truncated prompt/response previews.

        *metadata* may supply "tokens", "duration", "cost" and "cache_hit".
        NOTE(review): previews can still contain user data; consider redacting
        them if logs leave the trust boundary.
        """
        fields = {
            "provider": provider,
            "model": model,
            "prompt_preview": _preview(prompt),
            "response_preview": _preview(response),
            "tokens": metadata.get("tokens", {}),
            "duration": metadata.get("duration"),
            "cost": metadata.get("cost"),
            "cache_hit": metadata.get("cache_hit", False),
        }
        self.logger.info("ai_request", **fields)

# Security Best Practices
API Key Management
# chat-server/app/security/api_key_manager.py
import logging
import os
import re
from typing import Optional

logger = logging.getLogger(__name__)

# Loose shape checks. Real keys vary in length and may contain '-' and '_'
# (e.g. OpenAI "sk-proj-..." and Anthropic "sk-ant-api03-..." keys).
# NOTE(review): Azure keys were historically 32 hex characters; newer keys may
# be longer alphanumeric strings, so the Azure pattern is kept permissive.
_KEY_PATTERNS = {
    "openai": re.compile(r"sk-[A-Za-z0-9_-]{20,}"),
    "anthropic": re.compile(r"sk-ant-[A-Za-z0-9_-]{20,}"),
    "azure": re.compile(r"[A-Za-z0-9]{32,}"),
}


class APIKeyManager:
    """Encrypts/decrypts provider API keys with Fernet and checks key formats."""

    def __init__(self):
        # Imported here so validate_api_key_format is usable without the
        # cryptography package installed.
        from cryptography.fernet import Fernet

        encryption_key = os.getenv("ENCRYPTION_KEY")
        if not encryption_key:
            # A generated key exists only in this process: anything encrypted
            # with it cannot be decrypted after a restart or by another replica.
            logger.warning(
                "ENCRYPTION_KEY is not set; using a random per-process key. "
                "Values encrypted now cannot be decrypted after a restart."
            )
            encryption_key = Fernet.generate_key()
        self.cipher = Fernet(encryption_key)

    def encrypt_api_key(self, api_key: str) -> str:
        """Encrypt *api_key* and return the Fernet token as text for storage."""
        return self.cipher.encrypt(api_key.encode()).decode()

    def decrypt_api_key(self, encrypted_key: str) -> str:
        """Decrypt a token produced by encrypt_api_key back to the plain key."""
        return self.cipher.decrypt(encrypted_key.encode()).decode()

    @staticmethod
    def validate_api_key_format(api_key: str, provider: str) -> bool:
        """Return False if *api_key* clearly does not fit *provider*'s key shape.

        Providers without a known pattern are always accepted. The whole
        string must match, so a trailing newline (e.g. from an env file) is
        rejected.
        """
        pattern = _KEY_PATTERNS.get(provider)
        if pattern is None:
            return True
        return pattern.fullmatch(api_key) is not None

# Prompt Injection Prevention
# chat-server/app/security/prompt_security.py
import re

# Phrases stripped from user input. Whitespace inside a phrase is matched
# flexibly, so "ignore   previous\ninstructions" is caught too.
_INJECTION_PATTERNS = [
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"ignore\s+previous\s+instructions",
        r"disregard\s+all\s+prior",
        r"system:\s*",
        r"assistant:\s*",
        r"\[INST\]",
        r"\[/INST\]",
    )
]

# Statements that change data, schema, permissions or files (ATTACH/COPY can
# open databases or read/write files in engines such as DuckDB and
# PostgreSQL). Matched as whole words, so identifiers such as "created_at" or
# "is_deleted" are not rejected. A string literal like 'delete' is still
# rejected; that false positive errs on the safe side.
_DANGEROUS_SQL = re.compile(
    r"\b(?:DROP|DELETE|TRUNCATE|ALTER|CREATE|GRANT|REVOKE|EXECUTE"
    r"|INSERT|UPDATE|MERGE|ATTACH|COPY)\b",
    re.IGNORECASE,
)


class PromptSecurity:
    """Best-effort deny-list checks for user prompts and generated SQL."""

    @staticmethod
    def sanitize_user_input(user_input: str) -> str:
        """Strip known prompt-injection phrases from *user_input*.

        Removal repeats until nothing changes, so a nested payload such as
        "ignore previous ignore previous instructionsinstructions" cannot
        reassemble a phrase after one pass. This is a heuristic, not a
        complete defence against prompt injection.
        """
        sanitized = user_input
        while True:
            before = sanitized
            for pattern in _INJECTION_PATTERNS:
                sanitized = pattern.sub("", sanitized)
            # Each change shortens the string, so this loop terminates.
            if sanitized == before:
                return sanitized.strip()

    @staticmethod
    def validate_sql_output(sql: str) -> bool:
        """Return True if *sql* contains no data-, schema- or file-modifying keyword."""
        return _DANGEROUS_SQL.search(sql) is None

# Next Steps
- Deploy with Docker
- Set up Kubernetes Deployment
- Configure Monitoring
- Review Scaling Strategies