Initial commit
This commit is contained in:
153
backend/utils/cache.py
Normal file
153
backend/utils/cache.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Voice prompt caching utilities.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Dict, Any
|
||||
|
||||
from .. import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_cache_dir() -> Path:
|
||||
"""Get cache directory from config."""
|
||||
return config.get_cache_dir()
|
||||
|
||||
|
||||
# In-memory cache - can store dict (voice prompt) or tensor (legacy)
|
||||
_memory_cache: dict[str, Union[torch.Tensor, Dict[str, Any]]] = {}
|
||||
|
||||
|
||||
def get_cache_key(audio_path: str, reference_text: str) -> str:
|
||||
"""
|
||||
Generate cache key from audio file and reference text.
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file
|
||||
reference_text: Reference text
|
||||
|
||||
Returns:
|
||||
Cache key (MD5 hash)
|
||||
"""
|
||||
# Read audio file
|
||||
with open(audio_path, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
|
||||
# Combine audio bytes and text
|
||||
combined = audio_bytes + reference_text.encode("utf-8")
|
||||
|
||||
# Generate hash
|
||||
return hashlib.md5(combined).hexdigest()
|
||||
|
||||
|
||||
def get_cached_voice_prompt(
|
||||
cache_key: str,
|
||||
) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
|
||||
"""
|
||||
Get cached voice prompt if available.
|
||||
|
||||
Args:
|
||||
cache_key: Cache key
|
||||
|
||||
Returns:
|
||||
Cached voice prompt (dict or tensor) or None
|
||||
"""
|
||||
# Check in-memory cache
|
||||
if cache_key in _memory_cache:
|
||||
return _memory_cache[cache_key]
|
||||
|
||||
# Check disk cache
|
||||
cache_file = _get_cache_dir() / f"{cache_key}.prompt"
|
||||
if cache_file.exists():
|
||||
try:
|
||||
prompt = torch.load(cache_file, weights_only=True)
|
||||
_memory_cache[cache_key] = prompt
|
||||
return prompt
|
||||
except Exception:
|
||||
# Cache file corrupted, delete it
|
||||
cache_file.unlink()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def cache_voice_prompt(
|
||||
cache_key: str,
|
||||
voice_prompt: Union[torch.Tensor, Dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Cache voice prompt to memory and disk.
|
||||
|
||||
Args:
|
||||
cache_key: Cache key
|
||||
voice_prompt: Voice prompt (dict or tensor)
|
||||
"""
|
||||
# Store in memory
|
||||
_memory_cache[cache_key] = voice_prompt
|
||||
|
||||
# Store on disk (torch.save can handle both dicts and tensors)
|
||||
cache_file = _get_cache_dir() / f"{cache_key}.prompt"
|
||||
torch.save(voice_prompt, cache_file)
|
||||
|
||||
|
||||
def clear_voice_prompt_cache() -> int:
|
||||
"""
|
||||
Clear all voice prompt caches (memory and disk).
|
||||
|
||||
Returns:
|
||||
Number of cache files deleted
|
||||
"""
|
||||
# Clear memory cache
|
||||
_memory_cache.clear()
|
||||
|
||||
# Clear disk cache
|
||||
cache_dir = _get_cache_dir()
|
||||
deleted_count = 0
|
||||
|
||||
if cache_dir.exists():
|
||||
# Delete prompt cache files
|
||||
for cache_file in cache_dir.glob("*.prompt"):
|
||||
try:
|
||||
cache_file.unlink()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.warning("Failed to delete cache file %s: %s", cache_file, e)
|
||||
|
||||
# Delete combined audio files
|
||||
for audio_file in cache_dir.glob("combined_*.wav"):
|
||||
try:
|
||||
audio_file.unlink()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
def clear_profile_cache(profile_id: str) -> int:
|
||||
"""
|
||||
Clear cache files for a specific profile.
|
||||
|
||||
Args:
|
||||
profile_id: Profile ID
|
||||
|
||||
Returns:
|
||||
Number of cache files deleted
|
||||
"""
|
||||
cache_dir = _get_cache_dir()
|
||||
deleted_count = 0
|
||||
|
||||
if cache_dir.exists():
|
||||
# Delete combined audio files for this profile
|
||||
pattern = f"combined_{profile_id}_*.wav"
|
||||
for audio_file in cache_dir.glob(pattern):
|
||||
try:
|
||||
audio_file.unlink()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)
|
||||
|
||||
return deleted_count
|
||||
Reference in New Issue
Block a user