AI Providers Configuration

Setting up and configuring AI/LLM providers for GoPie

GoPie supports multiple AI providers for natural language processing, SQL generation, and embeddings. This guide covers configuration and optimization for each provider.

Overview

GoPie uses AI providers for:

  • Natural Language Understanding - Parsing user queries
  • SQL Generation - Converting questions to SQL
  • Schema Analysis - Understanding database structures
  • Embeddings - Semantic search of schemas
  • Code Execution - Python code generation for visualizations

OpenAI Configuration

Basic Setup

# .env configuration
OPENAI_API_KEY=sk-your-api-key
OPENAI_ORG_ID=org-your-org-id  # Optional
OPENAI_API_BASE=https://api.openai.com/v1  # Optional custom endpoint

Model Configuration

# chat-server/app/config/ai_config.py
from enum import Enum
from typing import Optional

from pydantic import BaseSettings

class OpenAIModel(str, Enum):
    """OpenAI chat model identifiers usable as ``OpenAIConfig.default_model``.

    The ``str`` mixin makes each member usable wherever a plain model-name
    string is expected.
    """
    GPT4 = "gpt-4"
    GPT4_TURBO = "gpt-4-turbo-preview"  # NOTE(review): preview alias; confirm it is still served
    GPT35_TURBO = "gpt-3.5-turbo"
    GPT35_TURBO_16K = "gpt-3.5-turbo-16k"  # NOTE(review): confirm this model name is still available
    
class EmbeddingModel(str, Enum):
    """OpenAI embedding model identifiers usable as ``OpenAIConfig.embedding_model``.

    ``ADA_003`` and ``ADA_003_LARGE`` are kept for backward compatibility.
    However, their values belong to the ``text-embedding-3`` family, not Ada.
    The ``TEXT_EMBEDDING_*`` names below are enum aliases: a repeated value
    resolves to the same member. As a result, ``list(EmbeddingModel)`` still
    yields exactly three members, and each member's ``.name`` is unchanged.
    """

    ADA_002 = "text-embedding-ada-002"
    ADA_003 = "text-embedding-3-small"
    ADA_003_LARGE = "text-embedding-3-large"

    # Accurately named aliases of the members above.
    TEXT_EMBEDDING_ADA_002 = "text-embedding-ada-002"
    TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small"
    TEXT_EMBEDDING_3_LARGE = "text-embedding-3-large"

class OpenAIConfig(BaseSettings):
    """OpenAI connection, model selection and request settings.

    ``Config.env_prefix = "OPENAI_"`` (below) makes pydantic's ``BaseSettings``
    populate each field from the matching ``OPENAI_<FIELD>`` environment
    variable, e.g. ``api_key`` from ``OPENAI_API_KEY``.

    NOTE(review): ``BaseSettings`` is importable from ``pydantic`` only in
    pydantic v1; pydantic v2 moved it to the ``pydantic-settings`` package.
    """
    # Required: no default, so constructing the settings fails without it.
    api_key: str
    # Optional organization ID; None when not configured.
    org_id: Optional[str] = None
    api_base: str = "https://api.openai.com/v1"
    
    # Model selection
    default_model: OpenAIModel = OpenAIModel.GPT4
    embedding_model: EmbeddingModel = EmbeddingModel.ADA_002
    
    # Model parameters (sampling settings for chat completions)
    temperature: float = 0.7
    max_tokens: int = 2000
    top_p: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    
    # Rate limiting / resilience
    max_retries: int = 3
    retry_delay: float = 1.0
    request_timeout: int = 60  # presumably seconds -- confirm against the client library
    
    # Cost optimization: response caching
    enable_caching: bool = True
    cache_ttl: int = 3600  # 1 hour
    
    class Config:
        # Prefix prepended to field names when reading environment variables.
        env_prefix = "OPENAI_"

LangChain Integration

# chat-server/app/services/llm_service.py
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.callbacks import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.cache import RedisCache
import redis

class LLMService:
    """LangChain wrapper around OpenAI chat and embedding clients.

    Args:
        config: Connection, model and caching settings (``OpenAIConfig``).
    """

    def __init__(self, config: OpenAIConfig):
        # Local imports: the snippet's import block does not provide ``os``.
        # The ``from langchain... import`` lines never bind the name
        # ``langchain`` itself, so the old ``langchain.llm_cache = ...``
        # raised NameError.
        import os
        from langchain.globals import set_llm_cache

        self.config = config

        # Install a process-wide Redis-backed LLM cache when enabled, honouring
        # the configured TTL (seconds).
        if config.enable_caching:
            redis_client = redis.Redis.from_url(
                os.getenv("REDIS_URL", "redis://localhost:6379")
            )
            set_llm_cache(RedisCache(redis_client, ttl=config.cache_ttl))

        # Chat model with token streaming echoed to stdout.
        self.llm = ChatOpenAI(
            api_key=config.api_key,
            organization=config.org_id,
            model_name=config.default_model,
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            streaming=True,
            # ``callbacks`` replaces the deprecated ``callback_manager`` argument.
            callbacks=[StreamingStdOutCallbackHandler()],
            request_timeout=config.request_timeout,
            max_retries=config.max_retries,
        )

        # Embedding client sharing credentials, timeout and retry settings.
        self.embeddings = OpenAIEmbeddings(
            api_key=config.api_key,
            organization=config.org_id,
            model=config.embedding_model,
            request_timeout=config.request_timeout,
            max_retries=config.max_retries,
        )

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Return *text* stripped, minus one enclosing Markdown ``` fence if present.

        A language tag on the opening fence line (e.g. ```sql) is dropped too.
        """
        import re

        stripped = text.strip()
        match = re.fullmatch(r"```(?:[A-Za-z0-9_+-]*\n)?(.*?)```", stripped, flags=re.DOTALL)
        return match.group(1).strip() if match else stripped

    async def generate_sql(self, question: str, schema: str) -> str:
        """Generate SQL from natural language question.

        Args:
            question: The user's natural-language question.
            schema: Database schema text included verbatim in the prompt.

        Returns:
            The model's reply with surrounding whitespace and any enclosing
            Markdown code fence removed.
        """
        prompt = f"""
        Given the following database schema:
        {schema}
        
        Generate a SQL query to answer this question:
        {question}
        
        Return only the SQL query without explanation.
        """

        response = await self.llm.ainvoke(prompt)
        # Chat models may wrap SQL in a fenced block despite the instruction;
        # the fence would make the query unexecutable.
        return self._strip_code_fences(response.content)

Anthropic Configuration

Setup

# .env configuration
ANTHROPIC_API_KEY=sk-ant-your-api-key
ANTHROPIC_MODEL=claude-3-opus-20240229  # or claude-3-sonnet-20240229
ANTHROPIC_MAX_TOKENS=4000

Anthropic Integration

# chat-server/app/services/anthropic_service.py
from anthropic import Anthropic, AsyncAnthropic
from typing import AsyncGenerator

class AnthropicService:
    """Streams SQL generated by an Anthropic Claude model.

    Reads ``ANTHROPIC_API_KEY``, ``ANTHROPIC_MODEL`` (default
    ``claude-3-opus-20240229``) and ``ANTHROPIC_MAX_TOKENS`` (default 4000)
    from the environment.
    """

    def __init__(self):
        import os  # the snippet's import block does not include ``os``

        self.client = AsyncAnthropic(
            api_key=os.getenv("ANTHROPIC_API_KEY"),
        )
        self.model = os.getenv("ANTHROPIC_MODEL", "claude-3-opus-20240229")
        self.max_tokens = int(os.getenv("ANTHROPIC_MAX_TOKENS", "4000"))

    async def generate_sql_streaming(
        self,
        question: str,
        schema: str
    ) -> AsyncGenerator[str, None]:
        """Generate SQL with streaming response.

        Args:
            question: The user's natural-language question.
            schema: Database schema text included verbatim in the prompt.

        Yields:
            Text fragments of the model's answer, in order.
        """
        prompt = f"""
        You are a SQL expert. Given this database schema:
        
        {schema}
        
        Generate a SQL query to answer: {question}
        
        Requirements:
        1. Use proper SQL syntax
        2. Handle edge cases
        3. Optimize for performance
        4. Return only the SQL query
        """

        stream = await self.client.messages.create(
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=0.3,  # Lower temperature for SQL generation
            messages=[
                {"role": "user", "content": prompt}
            ],
            stream=True,
        )

        async for event in stream:
            if event.type != "content_block_delta":
                continue
            # Only text deltas carry ``.text``. Other delta kinds (e.g. tool
            # input JSON) would raise AttributeError, so skip them.
            if getattr(event.delta, "type", None) == "text_delta":
                yield event.delta.text

Azure OpenAI Configuration

Setup

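# .env configuration
AZURE_OPENAI_API_KEY=your-azure-api-key
AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/
AZURE_OPENAI_DEPLOYMENT_NAME=your-deployment-name  # chat model deployment
AZURE_OPENAI_EMBEDDING_DEPLOYMENT=your-embedding-deployment-name  # embedding model deployment
AZURE_OPENAI_API_VERSION=2024-02-15-preview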

Azure Integration

# chat-server/app/services/azure_openai_service.py
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings

class AzureOpenAIService:
    """LangChain chat and embedding clients backed by Azure OpenAI deployments.

    Settings come from ``AZURE_OPENAI_*`` environment variables. Chat and
    embeddings use separate deployments (``AZURE_OPENAI_DEPLOYMENT_NAME`` and
    ``AZURE_OPENAI_EMBEDDING_DEPLOYMENT``) but share the endpoint, API key
    and API version.
    """

    def __init__(self):
        import os  # the snippet's import block does not include ``os``

        # Read the shared settings once so both clients use identical values.
        endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_key = os.getenv("AZURE_OPENAI_API_KEY")
        api_version = os.getenv("AZURE_OPENAI_API_VERSION")

        self.llm = AzureChatOpenAI(
            azure_endpoint=endpoint,
            api_key=api_key,
            azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
            api_version=api_version,
            temperature=0.7,
            max_tokens=2000,
            streaming=True,
        )

        self.embeddings = AzureOpenAIEmbeddings(
            azure_endpoint=endpoint,
            api_key=api_key,
            azure_deployment=os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT"),
            api_version=api_version,
        )

Local Model Support

Ollama Configuration

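# Install Ollama
curl -fsSL https://ollama.com/install.sh | sh

# Pull the chat model named in OLLAMA_MODEL and the embedding model used by OllamaService
ollama pull codellama:13b
ollama pull nomic-embed-text

# Optional alternative chat models
ollama pull llama2
ollama pull mistral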

# .env configuration
OLLAMA_BASE_URL=http://localhost:11434
OLLAMA_MODEL=codellama:13b

Ollama Integration

# chat-server/app/services/ollama_service.py
from langchain_community.llms import Ollama
from langchain_community.embeddings import OllamaEmbeddings

class OllamaService:
    """LangChain LLM and embedding clients for a local Ollama server.

    ``OLLAMA_BASE_URL`` (default ``http://localhost:11434``) is shared by both
    clients. ``OLLAMA_MODEL`` (default ``codellama:13b``) selects the
    generation model, and embeddings always use ``nomic-embed-text``.
    """

    def __init__(self):
        import os  # the snippet's import block does not include ``os``

        base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

        self.llm = Ollama(
            base_url=base_url,
            model=os.getenv("OLLAMA_MODEL", "codellama:13b"),
            temperature=0.7,
            num_predict=2000,  # Ollama's cap on generated tokens
            # NOTE(review): confirm the installed Ollama wrapper accepts `streaming`.
            streaming=True,
        )

        self.embeddings = OllamaEmbeddings(
            base_url=base_url,
            model="nomic-embed-text",  # Specialized embedding model
        )

Multi-Provider Strategy

Provider Selection

# chat-server/app/services/ai_provider_factory.py
from enum import Enum
from typing import Union

class AIProvider(str, Enum):
    """Supported AI backends, identified by lowercase provider names.

    Because the enum mixes in ``str``, members compare equal to their plain
    string values (e.g. ``AIProvider.OPENAI == "openai"``).
    """
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    AZURE = "azure"
    OLLAMA = "ollama"
    
class AIProviderFactory:
    """Constructs provider services and picks a preferred provider per task."""

    @staticmethod
    def create_llm_service(provider: AIProvider) -> BaseLLMService:
        """Build a fresh service instance for *provider*.

        Raises:
            ValueError: if *provider* matches none of the known providers.
        """
        # Builders are lambdas so only the selected class is looked up and
        # instantiated; matching uses ``==`` just like an if/elif chain.
        builders = (
            (AIProvider.OPENAI, lambda: OpenAIService()),
            (AIProvider.ANTHROPIC, lambda: AnthropicService()),
            (AIProvider.AZURE, lambda: AzureOpenAIService()),
            (AIProvider.OLLAMA, lambda: OllamaService()),
        )
        for candidate, build in builders:
            if provider == candidate:
                return build()
        raise ValueError(f"Unknown provider: {provider}")

    @staticmethod
    def get_provider_for_task(task_type: str) -> AIProvider:
        """Return the preferred provider for *task_type*, or OpenAI if unmapped."""
        fallback = AIProvider.OPENAI
        preferred = {
            "sql_generation": AIProvider.OPENAI,       # SQL generation
            "code_generation": AIProvider.ANTHROPIC,   # Python code generation
            "embeddings": AIProvider.OPENAI,           # schema embeddings
            "local_dev": AIProvider.OLLAMA,            # local development
        }
        return preferred.get(task_type, fallback)

Fallback Strategy

# chat-server/app/services/ai_fallback_service.py
class AIFallbackService:
    """Tries an ordered list of providers until one returns a response."""

    def __init__(self):
        # Order matters: providers are attempted first to last.
        self.providers = [
            (AIProvider.OPENAI, OpenAIService()),
            (AIProvider.ANTHROPIC, AnthropicService()),
            (AIProvider.AZURE, AzureOpenAIService()),
        ]

    async def generate_with_fallback(
        self,
        prompt: str,
        max_retries: int = 3
    ) -> str:
        """Return the first successful ``provider.generate(prompt)`` result.

        Args:
            prompt: Text forwarded unchanged to each provider.
            max_retries: Accepted for API compatibility. Each provider is
                currently tried exactly once, whatever this value is.

        Raises:
            RuntimeError: if every provider fails, or none are configured. The
                message lists each provider's error, and the last error is
                chained as ``__cause__``.
        """
        import logging

        # ``logger`` was never defined in this snippet, which made every call
        # fail with NameError; resolve a module logger here instead.
        logger = logging.getLogger(__name__)
        errors = []
        last_error = None

        for provider_name, provider in self.providers:
            try:
                response = await provider.generate(prompt)
            except Exception as exc:  # any provider failure triggers the next fallback
                errors.append(f"{provider_name}: {exc}")
                logger.warning(f"Provider {provider_name} failed: {exc}")
                last_error = exc
            else:
                logger.info(f"Successfully used {provider_name}")
                return response

        raise RuntimeError(f"All providers failed: {'; '.join(errors)}") from last_error

Embeddings Configuration

Vector Dimension Management

# chat-server/app/services/embedding_service.py
class EmbeddingService:
    """Batches texts for embedding and reports the vector size per provider.

    The per-batch call ``self._embed_batch(batch)`` is not defined in this
    class as shown; it must be provided elsewhere (e.g. by a subclass). It is
    awaited with a list of texts and presumably returns one vector per text.
    """

    def __init__(self):
        import os  # the snippet's import block does not include ``os``

        self.provider = os.getenv("EMBEDDING_PROVIDER", "openai")
        self.dimensions = self._get_dimensions()

    def _get_dimensions(self) -> int:
        """Return the embedding size for ``self.provider`` (1536 if unknown)."""
        dimension_map = {
            # "openai" is the EMBEDDING_PROVIDER default above. It previously
            # had no entry and only reached 1536 through the fallback.
            "openai": 1536,
            "openai-ada-002": 1536,
            "openai-3-small": 1536,
            "openai-3-large": 3072,
            "cohere-embed-v3": 1024,
            "voyage-2": 1024,
            "bge-large": 1024,
            "e5-large-v2": 1024,
        }
        return dimension_map.get(self.provider, 1536)

    # Annotations are quoted because ``List`` is not imported by this snippet;
    # unquoted, they raised NameError when the class was defined.
    async def generate_embeddings(
        self,
        texts: "List[str]",
        batch_size: int = 100
    ) -> "List[List[float]]":
        """Generate embeddings with batching, preserving input order.

        Args:
            texts: Texts to embed; an empty list yields an empty result.
            batch_size: Maximum number of texts per ``_embed_batch`` call.

        Raises:
            ValueError: if *batch_size* < 1. Previously 0 crashed inside
                ``range`` and a negative value silently returned no embeddings.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        all_embeddings = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            all_embeddings.extend(await self._embed_batch(batch))
        return all_embeddings

Cost Optimization

Token Counting

# chat-server/app/utils/token_counter.py
import tiktoken
from typing import Dict

class TokenCounter:
    """Approximate token counting and USD cost estimation for chat models.

    Model names are resolved to the longest configured key that is a prefix
    of the name. For example, "gpt-4-turbo-preview" maps to "gpt-4-turbo" and
    "claude-3-opus-20240229" maps to "claude-3-opus". Unknown models fall
    back to the "gpt-4" entries.
    """

    def __init__(self):
        self.encoders = {
            "gpt-4": tiktoken.encoding_for_model("gpt-4"),
            "gpt-3.5-turbo": tiktoken.encoding_for_model("gpt-3.5-turbo"),
            "claude": tiktoken.get_encoding("cl100k_base"),  # Approximation
        }

        # Cost per 1K tokens (in USD)
        self.costs = {
            "gpt-4": {"input": 0.03, "output": 0.06},
            "gpt-4-turbo": {"input": 0.01, "output": 0.03},
            "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
            "claude-3-opus": {"input": 0.015, "output": 0.075},
            "claude-3-sonnet": {"input": 0.003, "output": 0.015},
        }

    @staticmethod
    def _lookup(table: dict, model: str, default_key: str):
        """Return ``table[k]`` for the longest key *k* that prefixes *model*.

        Falls back to ``table[default_key]`` when no key matches.
        """
        matches = [key for key in table if model.startswith(key)]
        if not matches:
            return table[default_key]
        return table[max(matches, key=len)]

    def count_tokens(self, text: str, model: str) -> int:
        """Count tokens for *text* using the encoder that best matches *model*."""
        # The old ``model.split("-")[0]`` reduced "gpt-4" to "gpt", which
        # matched no key, so every GPT model silently used the fallback.
        encoder = self._lookup(self.encoders, model, "gpt-4")
        return len(encoder.encode(text))

    def estimate_cost(
        self,
        input_text: str,
        output_text: str,
        model: str
    ) -> Dict[str, float]:
        """Estimate the API cost of one request/response pair.

        Returns:
            Token counts for input and output, their USD costs, and
            ``total_cost``.
        """
        input_tokens = self.count_tokens(input_text, model)
        output_tokens = self.count_tokens(output_text, model)

        # Prefix matching so dated/preview model names get their own pricing
        # instead of the gpt-4 fallback.
        costs = self._lookup(self.costs, model, "gpt-4")

        input_cost = (input_tokens / 1000) * costs["input"]
        output_cost = (output_tokens / 1000) * costs["output"]

        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "input_cost": input_cost,
            "output_cost": output_cost,
            "total_cost": input_cost + output_cost,
        }

Caching Strategy

# chat-server/app/services/ai_cache_service.py
import hashlib
import json
from typing import Optional

class AICacheService:
    """Redis-backed cache of AI responses keyed by prompt, model and params.

    Args:
        redis_client: Async Redis client; ``get``, ``set`` and ``expire`` are awaited.
        ttl: Entry lifetime in seconds, refreshed on every hit. Defaults to 3600.
    """

    def __init__(self, redis_client, ttl: int = 3600):
        self.redis = redis_client
        self.ttl = ttl

    def _generate_cache_key(self, prompt: str, model: str, params: dict) -> str:
        """Generate deterministic cache key.

        ``sort_keys=True`` makes the key independent of *params* key order.
        """
        cache_data = {
            "prompt": prompt,
            "model": model,
            "params": params,
        }
        cache_str = json.dumps(cache_data, sort_keys=True)
        return f"ai_cache:{hashlib.sha256(cache_str.encode()).hexdigest()}"

    async def get_cached_response(
        self,
        prompt: str,
        model: str,
        params: dict
    ) -> Optional[str]:
        """Return the cached response for this request, or None on a miss."""
        key = self._generate_cache_key(prompt, model, params)
        cached = await self.redis.get(key)

        # ``is None`` so a cached empty string still counts as a hit.
        if cached is None:
            return None

        # Sliding expiration: every hit pushes the expiry out again.
        await self.redis.expire(key, self.ttl)

        # Clients created with ``decode_responses=True`` already return str.
        if isinstance(cached, (bytes, bytearray)):
            return cached.decode()
        return cached

    async def cache_response(
        self,
        prompt: str,
        model: str,
        params: dict,
        response: str
    ):
        """Store *response* for this request with the configured TTL."""
        key = self._generate_cache_key(prompt, model, params)
        await self.redis.set(key, response, ex=self.ttl)

Rate Limiting

Provider-Specific Limits

# chat-server/app/middleware/rate_limiter.py
from aioredis import Redis
import asyncio
from datetime import datetime, timedelta

class AIRateLimiter:
    """Fixed-window (per calendar minute) request and token limiter backed by Redis.

    Counters live under ``rate_limit:<provider>:<YYYYmmddHHMM>`` and
    ``token_limit:<provider>:<YYYYmmddHHMM>``. Each counter expires 60 seconds
    after it is first created.
    """

    # The annotation is quoted so defining the class does not require
    # resolving ``Redis``. Any client with async incr/incrby/decr/decrby/expire works.
    def __init__(self, redis: "Redis"):
        self.redis = redis

        # Rate limits per provider (requests per minute)
        self.limits = {
            "openai-gpt4": 200,
            "openai-gpt35": 3500,
            "anthropic-claude3": 1000,
            "azure-openai": 600,
        }

        # Token limits per minute
        self.token_limits = {
            "openai-gpt4": 40000,
            "openai-gpt35": 90000,
            "anthropic-claude3": 100000,
        }

    async def check_rate_limit(
        self,
        provider: str,
        tokens: int = 0
    ) -> bool:
        """Try to reserve one request (plus *tokens* tokens) in the current minute.

        Providers missing from ``self.limits`` get 1000 requests per minute.
        Providers missing from ``self.token_limits`` have no token limit.

        Returns:
            True if the reservation fits, in which case it is counted. False
            otherwise; the increments are then rolled back, so rejected
            attempts (including ``wait_if_needed`` polling) consume no quota.
        """
        window = datetime.now().strftime('%Y%m%d%H%M')
        minute_key = f"rate_limit:{provider}:{window}"
        token_key = f"token_limit:{provider}:{window}"

        # Reserve a request slot.
        current_count = await self.redis.incr(minute_key)
        if current_count == 1:
            await self.redis.expire(minute_key, 60)

        if current_count > self.limits.get(provider, 1000):
            await self.redis.decr(minute_key)
            return False

        # Reserve tokens, if this provider has a token budget.
        if tokens > 0 and provider in self.token_limits:
            current_tokens = await self.redis.incrby(token_key, tokens)
            if current_tokens == tokens:
                await self.redis.expire(token_key, 60)

            if current_tokens > self.token_limits[provider]:
                # This request will not be sent: release both reservations.
                await self.redis.decrby(token_key, tokens)
                await self.redis.decr(minute_key)
                return False

        return True

    async def wait_if_needed(self, provider: str, tokens: int = 0, poll_interval: float = 1.0):
        """Block until a request (and *tokens* tokens) is reserved for *provider*.

        On return the reservation has been counted against the current window.

        Args:
            provider: Key into ``self.limits`` / ``self.token_limits``.
            tokens: Tokens the upcoming request is expected to use.
            poll_interval: Seconds to sleep between attempts.

        Raises:
            ValueError: if *tokens* exceeds the provider's per-minute token
                limit, a request that could never be admitted.
        """
        limit = self.token_limits.get(provider)
        if tokens > 0 and limit is not None and tokens > limit:
            raise ValueError(
                f"Request needs {tokens} tokens but {provider} allows {limit} per minute"
            )
        while not await self.check_rate_limit(provider, tokens):
            await asyncio.sleep(poll_interval)

Monitoring and Observability

Metrics Collection

# chat-server/app/monitoring/ai_metrics.py
from prometheus_client import Counter, Histogram, Gauge
import time

# Define metrics
# AI API calls, labelled by provider, model and outcome status.
ai_requests_total = Counter(
    name='ai_requests_total',
    documentation='Total AI API requests',
    labelnames=['provider', 'model', 'status'],
)

# Distribution of AI API call durations.
ai_request_duration = Histogram(
    name='ai_request_duration_seconds',
    documentation='AI API request duration',
    labelnames=['provider', 'model'],
)

# Token consumption; the ``type`` label distinguishes input from output.
ai_tokens_used = Counter(
    name='ai_tokens_used_total',
    documentation='Total tokens used',
    labelnames=['provider', 'model', 'type'],
)

# Accumulated spend in US dollars.
ai_cost_total = Counter(
    name='ai_cost_usd_total',
    documentation='Total AI API cost in USD',
    labelnames=['provider', 'model'],
)

class AIMetricsCollector:
    """Records one AI API call into the module-level Prometheus metrics."""

    @staticmethod
    async def track_request(
        provider: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        duration: float,
        cost: float,
        success: bool
    ):
        """Update request count, duration, token usage and cost for one call.

        Args:
            provider: Provider label value.
            model: Model label value.
            input_tokens: Prompt tokens, added to the ``type='input'`` series.
            output_tokens: Completion tokens, added to the ``type='output'`` series.
            duration: Value observed into the duration histogram.
            cost: Amount added to the cost counter.
            success: Selects the ``status`` label ('success' or 'failure').
        """
        outcome = 'success' if success else 'failure'
        ai_requests_total.labels(provider=provider, model=model, status=outcome).inc()

        ai_request_duration.labels(provider=provider, model=model).observe(duration)

        # Input first, then output, matching one series per token type.
        for token_type, count in (('input', input_tokens), ('output', output_tokens)):
            ai_tokens_used.labels(provider=provider, model=model, type=token_type).inc(count)

        ai_cost_total.labels(provider=provider, model=model).inc(cost)

Logging Configuration

# chat-server/app/utils/ai_logger.py
import structlog
from typing import Dict, Any

class AILogger:
    """Emits one structured ``ai_request`` log event per AI call."""

    def __init__(self):
        self.logger = structlog.get_logger()

    def log_request(
        self,
        provider: str,
        model: str,
        prompt: str,
        response: str,
        metadata: Dict[str, Any]
    ):
        """Log an AI call with truncated previews and selected metadata.

        Prompt and response longer than 100 characters are cut to their
        first 100 characters plus "..."; shorter texts are logged whole.
        ``tokens``, ``duration``, ``cost`` and ``cache_hit`` are read from
        *metadata*, defaulting to ``{}``, None, None and False.

        NOTE(review): previews may still contain sensitive user data.
        """
        def preview(text):
            if len(text) <= 100:
                return text
            return text[:100] + "..."

        fields = dict(
            provider=provider,
            model=model,
            prompt_preview=preview(prompt),
            response_preview=preview(response),
            tokens=metadata.get("tokens", {}),
            duration=metadata.get("duration"),
            cost=metadata.get("cost"),
            cache_hit=metadata.get("cache_hit", False),
        )
        self.logger.info("ai_request", **fields)

Security Best Practices

API Key Management

# chat-server/app/security/api_key_manager.py
import os
from cryptography.fernet import Fernet
from typing import Optional

class APIKeyManager:
    """Encrypts/decrypts provider API keys with Fernet and sanity-checks their format."""

    def __init__(self):
        # Use environment variable or generate key
        encryption_key = os.getenv("ENCRYPTION_KEY")
        if not encryption_key:
            # With an ephemeral key, anything encrypted by this instance cannot
            # be decrypted after a restart; make that visible rather than silent.
            import warnings

            warnings.warn(
                "ENCRYPTION_KEY is not set; using a randomly generated key. "
                "API keys encrypted now cannot be decrypted after a restart.",
                RuntimeWarning,
                stacklevel=2,
            )
            encryption_key = Fernet.generate_key()

        self.cipher = Fernet(encryption_key)

    def encrypt_api_key(self, api_key: str) -> str:
        """Encrypt API key for storage; returns the Fernet token as text."""
        return self.cipher.encrypt(api_key.encode()).decode()

    def decrypt_api_key(self, encrypted_key: str) -> str:
        """Decrypt a token produced by ``encrypt_api_key`` back to the API key."""
        return self.cipher.decrypt(encrypted_key.encode()).decode()

    @staticmethod
    def validate_api_key_format(api_key: str, provider: str) -> bool:
        """Loosely check that *api_key* has the expected shape for *provider*.

        Unknown providers always pass. Only the prefix and allowed alphabet
        are enforced, because current keys are longer than the legacy 48-char
        body and contain ``-``/``_`` (e.g. ``sk-proj-...``, ``sk-ant-api03-...``).
        """
        import re

        patterns = {
            "openai": r"sk-[A-Za-z0-9_-]{20,}",
            "anthropic": r"sk-ant-[A-Za-z0-9_-]{20,}",
            # NOTE(review): confirm against the current Azure key format.
            "azure": r"[a-f0-9]{32}",
        }

        pattern = patterns.get(provider)
        if pattern is None:
            return True

        # fullmatch: ``re.match`` with ``$`` also accepted a trailing newline.
        return re.fullmatch(pattern, api_key) is not None

Prompt Injection Prevention

# chat-server/app/security/prompt_security.py
class PromptSecurity:
    """Heuristic guards for user prompts and generated SQL.

    These are best-effort filters, not a complete defence. Pattern stripping
    can be evaded by paraphrasing, and the SQL check is a keyword deny-list.
    """

    @staticmethod
    def sanitize_user_input(user_input: str) -> str:
        """Remove known prompt-injection phrases and role markers from *user_input*.

        Matching is case-insensitive and accepts any run of whitespace between
        words. Removal repeats until the text stops changing, so nested
        payloads cannot reassemble a phrase. For example, "ignore previous
        ignore previous instructions instructions" is fully removed. Every
        pattern matches at least one character, so each repeat shortens the
        text and the loop terminates.
        """
        import re

        dangerous_patterns = [
            r"ignore\s+previous\s+instructions",
            r"disregard\s+all\s+prior",
            r"system:\s*",
            r"assistant:\s*",
            r"\[INST\]",
            r"\[/INST\]",
        ]

        sanitized = user_input
        while True:
            previous = sanitized
            for pattern in dangerous_patterns:
                sanitized = re.sub(pattern, "", sanitized, flags=re.IGNORECASE)
            if sanitized == previous:
                break

        return sanitized.strip()

    @staticmethod
    def validate_sql_output(sql: str) -> bool:
        """Return False if *sql* contains a data-modifying or privileged keyword.

        Keywords match as whole words, case-insensitively. Identifiers such as
        ``created_at`` or ``deleted`` therefore no longer cause false
        rejections. A keyword anywhere still rejects the query, including
        inside string literals or comments; this is deliberately conservative.
        """
        import re

        dangerous_keywords = [
            "DROP", "DELETE", "TRUNCATE", "ALTER",
            "CREATE", "GRANT", "REVOKE", "EXECUTE",
            # Data-modifying statements. Assumes generated SQL should be
            # read-only -- NOTE(review): confirm.
            "INSERT", "UPDATE", "MERGE",
        ]
        pattern = r"\b(?:" + "|".join(dangerous_keywords) + r")\b"
        return re.search(pattern, sql, flags=re.IGNORECASE) is None

Next Steps