# Source code for config.provider_registry

"""
Script
------
provider_registry.py

Path
----
python/hillstar/config/provider_registry.py

Purpose
-------
Provider Registry: Central registry for LLM providers, models, and compliance rules.

Provides a ProviderRegistry class that loads provider configurations from JSON
and provides lookup methods for model selection, cost estimation, and compliance
verification. Supports package defaults with user overrides for customization.

Inputs
------
Provider registry JSON files (default + optional user override)

Outputs
-------
Registry instance with lookup methods for providers, models, and compliance

Assumptions
-----------
- Default registry file exists at package location
- User override follows same schema as default

Parameters
----------
None (per-query)

Failure Modes
-------------
- Missing default registry -> FileNotFoundError
- Malformed JSON -> JSONDecodeError
- Invalid provider/model -> lookups return None (estimate_cost returns 0.0)

Author: Julen Gamboa <julen.gamboa.ds@gmail.com>

Created
-------
2026-02-14

Last Edited
-----------
2026-02-14 (initial implementation)
"""

import json
import os
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple


class ProviderRegistry:
    """
    Load and query the provider registry with fallback to user overrides.

    The registry is loaded from:
    1. Package default: python/hillstar/config/provider_registry.default.json
    2. User override: ~/.hillstar/provider_registry.json (optional)

    A ``custom_registry_path`` passed to the constructor is used instead of
    the user override. The chosen override file is deep-merged on top of the
    package default, so it only needs to contain the keys it changes.
    """

    DEFAULT_REGISTRY_PATH = Path(__file__).parent / "provider_registry.default.json"
    USER_OVERRIDE_PATH = Path(os.path.expanduser("~/.hillstar/provider_registry.json"))

    # Cost tiers from cheapest to most expensive. Models whose tier is missing
    # or unrecognised are ranked as the most expensive tier.
    TIER_ORDER: Tuple[str, ...] = ("free", "affordable", "standard", "expensive", "premium")

    def __init__(self, custom_registry_path: Optional[str] = None):
        """
        Initialize the provider registry.

        Args:
            custom_registry_path: Optional path to a custom registry file.
                If provided, it is merged over the package default and the
                user override file is not consulted.

        Raises:
            FileNotFoundError: If the package default registry is missing.
            json.JSONDecodeError: If a registry file is not valid JSON.
        """
        self._registry: Dict[str, Any] = {}
        self._providers: Dict[str, Dict[str, Any]] = {}
        self._models_cache: Dict[Tuple[str, str], Dict[str, Any]] = {}
        self._load_registry(custom_registry_path)

    @staticmethod
    def _read_json(path: Path) -> Dict[str, Any]:
        """Read a registry JSON file, always decoding as UTF-8 (RFC 8259)."""
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def _load_registry(self, custom_registry_path: Optional[str] = None) -> None:
        """
        Load the package default and merge at most one override on top of it.

        Args:
            custom_registry_path: Optional explicit override file. When given,
                USER_OVERRIDE_PATH is ignored.

        Raises:
            FileNotFoundError: If the package default registry is missing.
        """
        # 1. Package default (mandatory)
        default_path = self.DEFAULT_REGISTRY_PATH
        if not default_path.exists():
            raise FileNotFoundError(
                f"Default provider registry not found at {default_path}. "
                "Please reinstall the hillstar package."
            )
        registry_data = self._read_json(default_path)

        # 2. Override: an explicit custom path wins over the user override file
        if custom_registry_path:
            custom_path = Path(custom_registry_path)
            if custom_path.exists():
                registry_data = self._merge_registry(
                    registry_data, self._read_json(custom_path)
                )
            else:
                # Previously ignored silently; surface it so a typo in the path
                # is noticed, while keeping the best-effort fallback.
                warnings.warn(
                    f"Custom provider registry not found at {custom_path}; "
                    "using package defaults only.",
                    stacklevel=3,
                )
        elif self.USER_OVERRIDE_PATH.exists():
            registry_data = self._merge_registry(
                registry_data, self._read_json(self.USER_OVERRIDE_PATH)
            )

        self._registry = registry_data
        self._providers = registry_data.get("providers", {})
        self._build_models_cache()

    def _merge_registry(
        self, base: Dict[str, Any], override: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Recursively merge ``override`` into ``base``, with override taking precedence.

        Nested dicts (providers, a provider's ``models``, a model's ``pricing``)
        are merged key by key, so an override can change one model without
        dropping the provider's other models. Non-dict values, including lists,
        are replaced wholesale. Neither argument is mutated.
        """
        result = dict(base)
        for key, value in override.items():
            existing = result.get(key)
            if isinstance(existing, dict) and isinstance(value, dict):
                result[key] = self._merge_registry(existing, value)
            else:
                result[key] = value
        return result

    def _build_models_cache(self) -> None:
        """Rebuild the flat (provider, model_id) -> model config lookup table."""
        self._models_cache = {
            (provider_name, model_id): model_data
            for provider_name, provider_data in self._providers.items()
            for model_id, model_data in provider_data.get("models", {}).items()
        }

    @classmethod
    def _tier_index(cls, tier: Optional[str]) -> int:
        """Rank ``tier`` within TIER_ORDER; unknown tiers rank as the most expensive."""
        try:
            return cls.TIER_ORDER.index(tier)
        except ValueError:
            return len(cls.TIER_ORDER) - 1

    @property
    def version(self) -> str:
        """Get registry version."""
        return self._registry.get("version", "unknown")

    @property
    def last_updated(self) -> str:
        """Get last update timestamp."""
        return self._registry.get("last_updated", "unknown")

    def list_providers(self, provider_type: Optional[str] = None) -> List[str]:
        """
        List available providers, optionally filtered by type.

        Args:
            provider_type: Optional filter: "cloud_api", "local", "local_proxy"

        Returns:
            List of provider names
        """
        if not provider_type:
            return list(self._providers)
        return [
            name
            for name, data in self._providers.items()
            if data.get("type") == provider_type
        ]

    def get_provider(self, provider_name: str) -> Optional[Dict[str, Any]]:
        """Get full provider configuration."""
        return self._providers.get(provider_name)

    def get_provider_compliance(self, provider_name: str) -> Optional[Dict[str, Any]]:
        """Get compliance rules for a provider."""
        provider = self.get_provider(provider_name)
        if provider:
            return provider.get("compliance")
        return None

    def get_model(self, provider_name: str, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get model configuration.

        Args:
            provider_name: Provider identifier (e.g., "anthropic")
            model_id: Model identifier (e.g., "claude-opus-4-6")

        Returns:
            Model configuration dict or None
        """
        provider = self._providers.get(provider_name)
        if provider:
            return provider.get("models", {}).get(model_id)
        return None

    def find_models(
        self,
        capabilities: Optional[List[str]] = None,
        max_tier: Optional[str] = None,
        provider_type: Optional[str] = None,
        require_ollama: Optional[bool] = None,
    ) -> List[Dict[str, Any]]:
        """
        Find models matching criteria.

        Args:
            capabilities: List of required capabilities (e.g., ["coding", "reasoning"])
            max_tier: Maximum cost tier, one of TIER_ORDER (e.g., "affordable", "standard")
            provider_type: Filter by provider type (e.g., "cloud_api", "local")
            require_ollama: If True, only models requiring Ollama; if False, only
                models not requiring it; if None, no filtering.

        Returns:
            List of matching model configs with provider context. The
            ``provider``/``model_id``/``display_name``/``tier`` keys always
            reflect the registry position, even if the model data has the same keys.

        Raises:
            ValueError: If ``max_tier`` is not a known tier.
        """
        if max_tier:
            if max_tier not in self.TIER_ORDER:
                raise ValueError(
                    f"Unknown max_tier {max_tier!r}; expected one of {list(self.TIER_ORDER)}"
                )
            max_tier_idx = self.TIER_ORDER.index(max_tier)
        else:
            max_tier_idx = len(self.TIER_ORDER) - 1

        matching = []
        for provider_name, provider_data in self._providers.items():
            if provider_type and provider_data.get("type") != provider_type:
                continue
            for model_id, model_data in provider_data.get("models", {}).items():
                if capabilities:
                    model_caps = model_data.get("capabilities", [])
                    if not all(cap in model_caps for cap in capabilities):
                        continue
                model_tier = model_data.get("tier", "premium")
                if self._tier_index(model_tier) > max_tier_idx:
                    continue
                if require_ollama is not None and require_ollama != model_data.get(
                    "requires_ollama", False
                ):
                    continue
                # Spread model data first so registry identity cannot be clobbered
                matching.append(
                    {
                        **model_data,
                        "provider": provider_name,
                        "model_id": model_id,
                        "display_name": model_data.get("display_name", model_id),
                        "tier": model_tier,
                    }
                )
        return matching

    def get_cheapest_model(
        self,
        capabilities: Optional[List[str]] = None,
        provider_preference: Optional[List[str]] = None,
    ) -> Optional[Tuple[str, str, Dict[str, Any]]]:
        """
        Get the cheapest model matching criteria, respecting provider preference.

        Candidates are ranked first by position in ``provider_preference``
        (providers not listed come after all listed ones), then by cost tier.

        Args:
            capabilities: Required capabilities
            provider_preference: Preferred provider order (e.g., ["anthropic", "openai"])

        Returns:
            Tuple of (provider, model_id, model_config) or None
        """
        candidates = self.find_models(capabilities=capabilities, max_tier="premium")
        if not candidates:
            return None

        preference = list(provider_preference or [])
        # Rank by list position; unlisted providers get a rank past every listed
        # one, so arbitrarily long preference lists still order correctly.
        unpreferred_rank = len(preference)

        def sort_key(item: Dict[str, Any]) -> Tuple[int, int]:
            provider = item["provider"]
            pref_rank = (
                preference.index(provider) if provider in preference else unpreferred_rank
            )
            return (pref_rank, self._tier_index(item.get("tier")))

        # min() keeps the first of equally ranked candidates, like a stable sort
        best = min(candidates, key=sort_key)
        return (best["provider"], best["model_id"], best)

    def estimate_cost(
        self,
        provider_name: str,
        model_id: str,
        input_tokens: int,
        output_tokens: int,
    ) -> float:
        """
        Estimate cost for a model call.

        Args:
            provider_name: Provider identifier
            model_id: Model identifier
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Estimated cost in USD (0.0 for unknown models or missing/null prices)
        """
        model = self.get_model(provider_name, model_id)
        if not model:
            return 0.0
        pricing = model.get("pricing") or {}
        input_rate = pricing.get("input_per_1m_usd") or 0
        output_rate = pricing.get("output_per_1m_usd") or 0
        input_cost = (input_tokens / 1_000_000) * input_rate
        output_cost = (output_tokens / 1_000_000) * output_rate
        return input_cost + output_cost

    def get_fallback_chain(
        self,
        complexity: str,
        provider_preference: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Get provider fallback chain for a complexity level.

        Args:
            complexity: Task complexity ("simple", "moderate", "complex", "critical")
            provider_preference: Preferred providers (highest priority first);
                names not present in the registry are skipped.

        Returns:
            A new list of providers in fallback order (safe for callers to mutate)
        """
        default_chain = self._registry.get("default_fallback_chain", {})
        chain = default_chain.get(
            complexity, default_chain.get("moderate", ["anthropic", "openai"])
        )
        if not provider_preference:
            # Copy so callers cannot mutate the registry's stored chain
            return list(chain)

        result: List[str] = []
        for pref in provider_preference:
            if pref in self._providers and pref not in result:
                result.append(pref)
        for provider in chain:
            if provider not in result:
                result.append(provider)
        return result

    def is_usage_compliant(
        self,
        provider_name: str,
        use_case: str,
    ) -> Tuple[bool, str]:
        """
        Check if a use case is compliant for a provider.

        Restricted use cases are always rejected; if an allow-list exists, the
        use case must be on it. Providers without compliance rules allow all.

        Args:
            provider_name: Provider identifier
            use_case: Intended use case (e.g., "research", "commercial")

        Returns:
            Tuple of (is_compliant, reason)
        """
        compliance = self.get_provider_compliance(provider_name)
        if not compliance:
            return (True, "No compliance rules defined")
        allowed = compliance.get("allowed_use_cases", [])
        restricted = compliance.get("restricted_use_cases", [])
        if use_case in restricted:
            return (False, f"Use case '{use_case}' is restricted by provider")
        if allowed and use_case not in allowed:
            return (False, f"Use case '{use_case}' not in allowed list: {allowed}")
        return (True, "Use case is allowed")

    def get_model_sampling_params(
        self,
        provider_name: str,
        model_id: str,
    ) -> Dict[str, Any]:
        """
        Get default sampling parameters for a model.

        Returns a shallow copy so callers can update it without altering the
        registry; an empty dict when the model or its params are missing.
        """
        model = self.get_model(provider_name, model_id)
        if model:
            return dict(model.get("default_sampling_params") or {})
        return {}

    def get_all_models_flat(self) -> Dict[Tuple[str, str], Dict[str, Any]]:
        """Get a flat dictionary of all (provider, model_id) -> model_config."""
        return self._models_cache.copy()

    def describe(self) -> str:
        """Get a human-readable description of the registry."""
        lines = [
            f"Provider Registry v{self.version}",
            f"Last updated: {self.last_updated}",
            "",
            "Providers:",
        ]
        for provider_name, provider_data in sorted(self._providers.items()):
            model_count = len(provider_data.get("models", {}))
            provider_type = provider_data.get("type", "unknown")
            display = provider_data.get("display_name", provider_name)
            lines.append(
                f" - {provider_name}: {display} ({model_count} models, {provider_type})"
            )
        return "\n".join(lines)
# Process-wide cached ProviderRegistry used by get_registry(); stays None
# until the first call builds it (and again after reset_registry()).
_registry_instance: "Optional[ProviderRegistry]" = None
def get_registry() -> "ProviderRegistry":
    """
    Return the shared ProviderRegistry, building it lazily on first use.

    Later calls hand back the same cached object until reset_registry()
    clears it.
    """
    global _registry_instance
    if _registry_instance is not None:
        return _registry_instance
    _registry_instance = ProviderRegistry()
    return _registry_instance
def reset_registry() -> None:
    """
    Forget the cached global registry (useful for testing).

    The next get_registry() call will construct a fresh ProviderRegistry.
    """
    global _registry_instance
    # Rebind the module global rather than touching the old instance.
    _registry_instance = None