"""
CUDA backend download, assembly, and verification.

Downloads two archives from GitHub Releases:
  1. Server core (voicebox-server-cuda.tar.gz) — the exe + non-NVIDIA deps,
     versioned with the app.
  2. CUDA libs (cuda-libs-{version}.tar.gz) — NVIDIA runtime libraries,
     versioned independently (only redownloaded on CUDA toolkit bump).

Both archives are extracted into {data_dir}/backends/cuda/ which forms the
complete PyInstaller --onedir directory structure that torch expects.

Both archives are fetched from the same release tag. The server core is
re-fetched when the installed binary's ``--version`` output does not match the
app version. A local cuda-libs.json manifest in the CUDA dir records which
CUDA libs version is installed, so the (large) libs archive is only fetched
again when CUDA_LIBS_VERSION changes.
"""

import asyncio
import hashlib
import json
import logging
import os
import sys
import tarfile
from pathlib import Path
from typing import Optional

from ..config import get_data_dir
from ..utils.progress import get_progress_manager
from .. import __version__

logger = logging.getLogger(__name__)

# Base URL for release assets; the release tag (e.g. "v0.3.0") and the asset
# filename are appended to it.
GITHUB_RELEASES_URL = "https://github.com/jamiepine/voicebox/releases/download"
# Key under which download progress is reported to the progress manager.
PROGRESS_KEY = "cuda-backend"

# The current expected CUDA libs version. Bump this when we change the
# CUDA toolkit version or torch's CUDA dependency changes (e.g. cu126 -> cu128).
# It names the release asset cuda-libs-{CUDA_LIBS_VERSION}.tar.gz and is compared
# against the "version" field of the installed cuda-libs.json manifest.
CUDA_LIBS_VERSION = "cu128-v1"

# Prevents concurrent download_cuda_binary() calls from racing on the same
# temp file. The auto-update background task and the manual HTTP endpoint
# can both invoke download_cuda_binary(); without this lock the progress-
# manager status check is a TOCTOU race.
# Also acquired by delete_cuda_binary() so the directory is never removed
# while an archive is being extracted into it.
_download_lock = asyncio.Lock()


def get_backends_dir() -> Path:
    """Return ``{data_dir}/backends``, creating it if it does not exist."""
    d = get_data_dir() / "backends"
    d.mkdir(parents=True, exist_ok=True)
    return d


def get_cuda_dir() -> Path:
    """Return the directory the CUDA onedir is extracted into, creating it if needed."""
    d = get_backends_dir() / "cuda"
    d.mkdir(parents=True, exist_ok=True)
    return d


def get_cuda_exe_name() -> str:
    """Return the platform-specific CUDA executable filename."""
    if sys.platform == "win32":
        return "voicebox-server-cuda.exe"
    return "voicebox-server-cuda"


def get_cuda_binary_path() -> Optional[Path]:
    """Return the path to the CUDA executable inside the onedir, or None if absent."""
    p = get_cuda_dir() / get_cuda_exe_name()
    if p.exists():
        return p
    return None


def get_cuda_libs_manifest_path() -> Path:
    """Return the path of the ``cuda-libs.json`` manifest inside the CUDA dir."""
    return get_cuda_dir() / "cuda-libs.json"


def get_installed_cuda_libs_version() -> Optional[str]:
    """Return the ``version`` recorded in cuda-libs.json, or None.

    None is returned when the manifest is missing or cannot be read/parsed;
    read failures are logged as warnings rather than raised.
    """
    manifest_path = get_cuda_libs_manifest_path()
    if not manifest_path.exists():
        return None
    try:
        data = json.loads(manifest_path.read_text())
        return data.get("version")
    except Exception as e:
        logger.warning(f"Could not read cuda-libs.json: {e}")
        return None


def is_cuda_active() -> bool:
    """Check if the current process is the CUDA binary.

    The CUDA binary sets this env var on startup (see server.py).
    """
    return os.environ.get("VOICEBOX_BACKEND_VARIANT") == "cuda"


def get_cuda_status() -> dict:
    """Get current CUDA backend status for the API.

    Returns a dict with ``available`` (executable present), ``active`` (this
    process is the CUDA binary), ``binary_path``, ``cuda_libs_version``,
    ``downloading`` (progress status is "downloading") and the raw
    ``download_progress`` entry from the progress manager.
    """
    progress_manager = get_progress_manager()
    cuda_path = get_cuda_binary_path()
    progress = progress_manager.get_progress(PROGRESS_KEY)
    cuda_libs_version = get_installed_cuda_libs_version()

    return {
        "available": cuda_path is not None,
        "active": is_cuda_active(),
        "binary_path": str(cuda_path) if cuda_path else None,
        "cuda_libs_version": cuda_libs_version,
        "downloading": progress is not None and progress.get("status") == "downloading",
        "download_progress": progress,
    }


def _normalize_version(version: Optional[str]) -> Optional[str]:
    """Strip whitespace and one leading "v" so "v0.3.0" and "0.3.0" compare equal."""
    if version is None:
        return None
    version = version.strip()
    return version[1:] if version.startswith("v") else version


def _server_version_is_current(installed: Optional[str], version: Optional[str] = None) -> bool:
    """Return True if the installed server version equals the expected one.

    Args:
        installed: Version reported by the CUDA binary, or None if unknown.
        version: Expected version/tag (e.g. "v0.3.0"); defaults to the app version.
    """
    if installed is None:
        return False
    # Normalize BOTH sides: previously only the expected value had its "v"
    # stripped, so a binary reporting "v0.3.0" was redownloaded forever.
    return _normalize_version(installed) == _normalize_version(version or __version__)


def _needs_server_download(version: Optional[str] = None) -> bool:
    """Check if the server core archive needs to be (re)downloaded.

    Note: this runs the CUDA binary with ``--version`` (blocking, up to 30 s);
    async callers should invoke it via an executor.
    """
    if get_cuda_binary_path() is None:
        return True
    return not _server_version_is_current(get_cuda_binary_version(), version)


def _needs_cuda_libs_download() -> bool:
    """Check if the CUDA libs archive needs to be (re)downloaded."""
    installed = get_installed_cuda_libs_version()
    if installed is None:
        return True
    return installed != CUDA_LIBS_VERSION


def _parse_expected_sha256(text: str) -> str:
    """Extract the lowercase SHA-256 hex digest from a ``.sha256`` file body.

    Accepts the common ``"<digest>  <filename>"`` layout, an optional UTF-8 BOM
    and upper-case hex (as written by e.g. PowerShell's Get-FileHash).

    Raises:
        ValueError: if the first token is not a 64-character hex digest.
    """
    tokens = text.lstrip("\ufeff").split()
    if not tokens:
        raise ValueError("checksum file is empty")
    digest = tokens[0].lower()
    if len(digest) != 64 or not set(digest) <= set("0123456789abcdef"):
        raise ValueError(f"checksum file does not start with a SHA-256 digest: {digest[:80]!r}")
    return digest


def _extract_tar_archive(archive_path: Path, dest_dir: Path) -> None:
    """Extract a .tar.gz into dest_dir, refusing members that would escape it.

    Uses the PEP 706 "data" extraction filter whenever the interpreter provides
    it (3.12+, and the 3.8.17 / 3.9.17 / 3.10.12 / 3.11.4+ security backports).
    On older interpreters, member paths and link targets are checked manually.
    This is synchronous and may take a long time; run it in an executor.

    Raises:
        tarfile.TarError / ValueError: on a corrupt archive or an unsafe member.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=dest_dir, filter="data")
            return

        root = os.path.realpath(dest_dir)

        def _inside_root(path: str) -> bool:
            # commonpath raises ValueError for paths on different drives (Windows),
            # which is also a rejection.
            return os.path.commonpath([root, os.path.realpath(path)]) == root

        for member in tar.getmembers():
            target = os.path.join(root, member.name)
            if not _inside_root(target):
                raise ValueError(f"Refusing to extract {member.name!r}: outside {dest_dir}")
            if member.issym():
                # Symlink targets are relative to the link's own directory.
                link_target = os.path.join(os.path.dirname(target), member.linkname)
            elif member.islnk():
                # Hard-link targets are relative to the archive root.
                link_target = os.path.join(root, member.linkname)
            else:
                continue
            if not _inside_root(link_target):
                raise ValueError(
                    f"Refusing to extract link {member.name!r} -> {member.linkname!r}: outside {dest_dir}"
                )
        tar.extractall(path=dest_dir)


async def _download_and_extract_archive(
    client,
    url: str,
    sha256_url: Optional[str],
    dest_dir: Path,
    label: str,
    progress_offset: int,
    total_size: int,
):
    """Download a .tar.gz archive and extract it into dest_dir.

    The archive is streamed to a temp file in dest_dir, hashed while streaming,
    verified against the published checksum (when sha256_url is given) and then
    extracted off the event loop. The temp file is always removed afterwards.

    Args:
        client: httpx.AsyncClient
        url: URL of the .tar.gz archive
        sha256_url: URL of the .sha256 checksum file (optional)
        dest_dir: Directory to extract into
        label: Human-readable label for progress updates
        progress_offset: Byte offset for progress reporting (when downloading
            multiple archives sequentially)
        total_size: Total bytes across all downloads (for progress bar)

    Returns:
        Number of bytes downloaded.

    Raises:
        RuntimeError: if the checksum file cannot be fetched or parsed.
        ValueError: if the integrity check fails or the archive has unsafe members.
    """
    progress = get_progress_manager()
    temp_path = dest_dir / f".download-{label.replace(' ', '-')}.tmp"

    # Clean up leftover partial download
    if temp_path.exists():
        temp_path.unlink()

    # Fetch expected checksum (fail-fast: never extract an unverified archive)
    expected_sha = None
    if sha256_url:
        try:
            sha_resp = await client.get(sha256_url)
            sha_resp.raise_for_status()
            expected_sha = _parse_expected_sha256(sha_resp.text)
            logger.info(f"{label}: expected SHA-256: {expected_sha[:16]}...")
        except Exception as e:
            raise RuntimeError(f"{label}: failed to fetch checksum from {sha256_url}") from e

    # Hash while streaming so a multi-GB archive is not re-read from disk
    # (synchronously, on the event loop) just to verify it.
    hasher = hashlib.sha256() if expected_sha else None

    # Stream download, verify, and extract — always clean up temp file
    downloaded = 0
    try:
        async with client.stream("GET", url) as response:
            response.raise_for_status()
            with open(temp_path, "wb") as f:
                async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
                    f.write(chunk)
                    if hasher is not None:
                        hasher.update(chunk)
                    downloaded += len(chunk)
                    progress.update_progress(
                        PROGRESS_KEY,
                        current=progress_offset + downloaded,
                        total=total_size,
                        filename=f"Downloading {label}",
                        status="downloading",
                    )

        # Verify integrity
        if hasher is not None:
            progress.update_progress(
                PROGRESS_KEY,
                current=progress_offset + downloaded,
                total=total_size,
                filename=f"Verifying {label}...",
                status="downloading",
            )
            actual = hasher.hexdigest()
            if actual != expected_sha:
                raise ValueError(
                    f"{label} integrity check failed: expected {expected_sha[:16]}..., got {actual[:16]}..."
                )
            logger.info(f"{label}: integrity verified")

        # Extract in a worker thread: extracting the CUDA libs takes long enough
        # that doing it inline would stall every other request on this server.
        progress.update_progress(
            PROGRESS_KEY,
            current=progress_offset + downloaded,
            total=total_size,
            filename=f"Extracting {label}...",
            status="downloading",
        )
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _extract_tar_archive, temp_path, dest_dir)
        logger.info(f"{label}: extracted to {dest_dir}")
    finally:
        # Best-effort cleanup: on Windows the file may still be held open (e.g.
        # by an extraction thread after cancellation); don't mask the real error.
        try:
            if temp_path.exists():
                temp_path.unlink()
        except OSError as e:
            logger.warning(f"{label}: could not remove temp file {temp_path}: {e}")

    return downloaded


async def download_cuda_binary(version: Optional[str] = None):
    """Download the CUDA backend (server core + CUDA libs if needed).

    Downloads both archives from GitHub Releases, extracts them into
    {data_dir}/backends/cuda/, and writes the cuda-libs.json manifest.

    Only downloads what's needed:
    - Server core: redownloaded when the installed binary's version differs
      from the requested/app version
    - CUDA libs: only if missing or version mismatch

    If another download (or a deletion) currently holds the lock, this call
    logs and returns immediately without waiting.

    Args:
        version: Version tag (e.g. "v0.3.0"). Defaults to current app version.
    """
    if _download_lock.locked():
        logger.info("CUDA download already in progress, skipping duplicate request")
        return

    async with _download_lock:
        await _download_cuda_binary_locked(version)


async def _content_length(client, url: str) -> int:
    """Best-effort size of ``url`` via a HEAD request; 0 when unknown or on error."""
    try:
        head = await client.head(url)
        head.raise_for_status()
        return int(head.headers.get("content-length", 0))
    except Exception:
        return 0


async def _download_cuda_binary_locked(version: Optional[str] = None):
    """Inner implementation of download_cuda_binary, called under _download_lock.

    Progress is reported under PROGRESS_KEY and always ends in either
    ``mark_complete`` or ``mark_error`` (including on task cancellation).
    """
    import httpx

    if version is None:
        version = f"v{__version__}"

    progress = get_progress_manager()
    cuda_dir = get_cuda_dir()

    # _needs_server_download() runs the installed binary (blocking, up to 30 s),
    # so keep it off the event loop.
    loop = asyncio.get_running_loop()
    need_server = await loop.run_in_executor(None, _needs_server_download, version)
    need_libs = _needs_cuda_libs_download()

    if not need_server and not need_libs:
        logger.info("CUDA backend is up to date, nothing to download")
        return

    logger.info(
        f"Starting CUDA backend download for {version} "
        f"(server={'yes' if need_server else 'cached'}, "
        f"libs={'yes' if need_libs else 'cached'})"
    )
    progress.update_progress(
        PROGRESS_KEY,
        current=0,
        total=0,
        filename="Preparing download...",
        status="downloading",
    )

    base_url = f"{GITHUB_RELEASES_URL}/{version}"
    server_archive = "voicebox-server-cuda.tar.gz"
    libs_archive = f"cuda-libs-{CUDA_LIBS_VERSION}.tar.gz"

    try:
        async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as client:
            # Estimate total download size (0 for any archive whose size is unknown)
            total_size = 0
            if need_server:
                total_size += await _content_length(client, f"{base_url}/{server_archive}")
            if need_libs:
                total_size += await _content_length(client, f"{base_url}/{libs_archive}")
            logger.info(f"Total download size: {total_size / 1024 / 1024:.1f} MB")

            offset = 0

            # Download server core
            if need_server:
                server_downloaded = await _download_and_extract_archive(
                    client,
                    url=f"{base_url}/{server_archive}",
                    sha256_url=f"{base_url}/{server_archive}.sha256",
                    dest_dir=cuda_dir,
                    label="CUDA server",
                    progress_offset=offset,
                    total_size=total_size,
                )
                offset += server_downloaded

                # Make executable on Unix
                exe_path = cuda_dir / get_cuda_exe_name()
                if sys.platform != "win32" and exe_path.exists():
                    exe_path.chmod(0o755)

            # Download CUDA libs
            if need_libs:
                await _download_and_extract_archive(
                    client,
                    url=f"{base_url}/{libs_archive}",
                    sha256_url=f"{base_url}/{libs_archive}.sha256",
                    dest_dir=cuda_dir,
                    label="CUDA libraries",
                    progress_offset=offset,
                    total_size=total_size,
                )

                # Write local cuda-libs.json manifest (only after a verified extract)
                manifest = {"version": CUDA_LIBS_VERSION}
                get_cuda_libs_manifest_path().write_text(json.dumps(manifest, indent=2) + "\n")

        logger.info(f"CUDA backend ready at {cuda_dir}")
        progress.mark_complete(PROGRESS_KEY)

    except asyncio.CancelledError:
        # CancelledError is a BaseException on 3.8+, so the handler below would
        # miss it and leave the status stuck on "downloading".
        logger.warning("CUDA backend download cancelled")
        progress.mark_error(PROGRESS_KEY, "Download cancelled")
        raise
    except Exception as e:
        logger.error(f"CUDA backend download failed: {e}")
        progress.mark_error(PROGRESS_KEY, str(e))
        raise


def get_cuda_binary_version() -> Optional[str]:
    """Get the version of the installed CUDA binary, or None if not installed.

    Runs ``<binary> --version`` synchronously (timeout 30 s) and returns the
    last whitespace-separated token of the first stdout line containing
    "voicebox-server". Failures are logged and yield None.
    """
    import subprocess

    cuda_path = get_cuda_binary_path()
    if not cuda_path:
        return None
    try:
        result = subprocess.run(
            [str(cuda_path), "--version"],
            capture_output=True,
            text=True,
            timeout=30,
            cwd=str(cuda_path.parent),  # Run from the onedir directory
        )
        # Output format: "voicebox-server 0.3.0"
        for line in result.stdout.strip().splitlines():
            if "voicebox-server" in line:
                return line.split()[-1]
    except Exception as e:
        logger.warning(f"Could not get CUDA binary version: {e}")
    return None


async def check_and_update_cuda_binary():
    """Check if the CUDA binary is outdated and auto-download if so.

    Called on server startup. Checks both server version and CUDA libs version.
    Downloads only what's needed. Failures are logged, never raised.
    """
    cuda_path = get_cuda_binary_path()
    if not cuda_path:
        return  # No CUDA binary installed, nothing to update

    # Query the binary once, off the event loop (it is a blocking subprocess).
    loop = asyncio.get_running_loop()
    cuda_version = await loop.run_in_executor(None, get_cuda_binary_version)
    need_server = not _server_version_is_current(cuda_version)
    need_libs = _needs_cuda_libs_download()

    if not need_server and not need_libs:
        logger.info(f"CUDA binary is up to date (server=v{__version__}, libs={get_installed_cuda_libs_version()})")
        return

    reasons = []
    if need_server:
        reasons.append(f"server v{cuda_version} != v{__version__}")
    if need_libs:
        installed_libs = get_installed_cuda_libs_version()
        reasons.append(f"libs {installed_libs} != {CUDA_LIBS_VERSION}")

    logger.info(f"CUDA backend needs update ({', '.join(reasons)}). Auto-downloading...")

    try:
        await download_cuda_binary()
    except Exception as e:
        logger.error(f"Auto-update of CUDA binary failed: {e}")


async def delete_cuda_binary() -> bool:
    """Delete the downloaded CUDA backend directory. Returns True if deleted.

    Holds _download_lock, so if a download is in progress this waits for it to
    finish instead of deleting files mid-extraction. The (potentially large)
    directory removal runs in a worker thread.
    NOTE(review): deleting while the CUDA binary is the running process
    (is_cuda_active()) is not guarded here — confirm callers prevent that.
    """
    import shutil

    async with _download_lock:
        cuda_dir = get_cuda_dir()
        if not (cuda_dir.exists() and any(cuda_dir.iterdir())):
            return False
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, shutil.rmtree, cuda_dir)
        logger.info(f"Deleted CUDA backend directory: {cuda_dir}")
        return True