154 lines
3.8 KiB
Python
154 lines
3.8 KiB
Python
"""
|
|
Voice prompt caching utilities.
|
|
"""
|
|
|
|
import hashlib
|
|
import logging
|
|
import torch
|
|
from pathlib import Path
|
|
from typing import Optional, Union, Dict, Any
|
|
|
|
from .. import config
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _get_cache_dir() -> Path:
|
|
"""Get cache directory from config."""
|
|
return config.get_cache_dir()
|
|
|
|
|
|
# In-memory cache - can store dict (voice prompt) or tensor (legacy)
|
|
_memory_cache: dict[str, Union[torch.Tensor, Dict[str, Any]]] = {}
|
|
|
|
|
|
def get_cache_key(audio_path: str, reference_text: str) -> str:
|
|
"""
|
|
Generate cache key from audio file and reference text.
|
|
|
|
Args:
|
|
audio_path: Path to audio file
|
|
reference_text: Reference text
|
|
|
|
Returns:
|
|
Cache key (MD5 hash)
|
|
"""
|
|
# Read audio file
|
|
with open(audio_path, "rb") as f:
|
|
audio_bytes = f.read()
|
|
|
|
# Combine audio bytes and text
|
|
combined = audio_bytes + reference_text.encode("utf-8")
|
|
|
|
# Generate hash
|
|
return hashlib.md5(combined).hexdigest()
|
|
|
|
|
|
def get_cached_voice_prompt(
|
|
cache_key: str,
|
|
) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
|
|
"""
|
|
Get cached voice prompt if available.
|
|
|
|
Args:
|
|
cache_key: Cache key
|
|
|
|
Returns:
|
|
Cached voice prompt (dict or tensor) or None
|
|
"""
|
|
# Check in-memory cache
|
|
if cache_key in _memory_cache:
|
|
return _memory_cache[cache_key]
|
|
|
|
# Check disk cache
|
|
cache_file = _get_cache_dir() / f"{cache_key}.prompt"
|
|
if cache_file.exists():
|
|
try:
|
|
prompt = torch.load(cache_file, weights_only=True)
|
|
_memory_cache[cache_key] = prompt
|
|
return prompt
|
|
except Exception:
|
|
# Cache file corrupted, delete it
|
|
cache_file.unlink()
|
|
|
|
return None
|
|
|
|
|
|
def cache_voice_prompt(
|
|
cache_key: str,
|
|
voice_prompt: Union[torch.Tensor, Dict[str, Any]],
|
|
) -> None:
|
|
"""
|
|
Cache voice prompt to memory and disk.
|
|
|
|
Args:
|
|
cache_key: Cache key
|
|
voice_prompt: Voice prompt (dict or tensor)
|
|
"""
|
|
# Store in memory
|
|
_memory_cache[cache_key] = voice_prompt
|
|
|
|
# Store on disk (torch.save can handle both dicts and tensors)
|
|
cache_file = _get_cache_dir() / f"{cache_key}.prompt"
|
|
torch.save(voice_prompt, cache_file)
|
|
|
|
|
|
def clear_voice_prompt_cache() -> int:
|
|
"""
|
|
Clear all voice prompt caches (memory and disk).
|
|
|
|
Returns:
|
|
Number of cache files deleted
|
|
"""
|
|
# Clear memory cache
|
|
_memory_cache.clear()
|
|
|
|
# Clear disk cache
|
|
cache_dir = _get_cache_dir()
|
|
deleted_count = 0
|
|
|
|
if cache_dir.exists():
|
|
# Delete prompt cache files
|
|
for cache_file in cache_dir.glob("*.prompt"):
|
|
try:
|
|
cache_file.unlink()
|
|
deleted_count += 1
|
|
except Exception as e:
|
|
logger.warning("Failed to delete cache file %s: %s", cache_file, e)
|
|
|
|
# Delete combined audio files
|
|
for audio_file in cache_dir.glob("combined_*.wav"):
|
|
try:
|
|
audio_file.unlink()
|
|
deleted_count += 1
|
|
except Exception as e:
|
|
logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)
|
|
|
|
return deleted_count
|
|
|
|
|
|
def clear_profile_cache(profile_id: str) -> int:
|
|
"""
|
|
Clear cache files for a specific profile.
|
|
|
|
Args:
|
|
profile_id: Profile ID
|
|
|
|
Returns:
|
|
Number of cache files deleted
|
|
"""
|
|
cache_dir = _get_cache_dir()
|
|
deleted_count = 0
|
|
|
|
if cache_dir.exists():
|
|
# Delete combined audio files for this profile
|
|
pattern = f"combined_{profile_id}_*.wav"
|
|
for audio_file in cache_dir.glob(pattern):
|
|
try:
|
|
audio_file.unlink()
|
|
deleted_count += 1
|
|
except Exception as e:
|
|
logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)
|
|
|
|
return deleted_count
|