"""Voice profile management module.""" import json as _json import logging import shutil import uuid from datetime import datetime from pathlib import Path from sqlalchemy import func from sqlalchemy.orm import Session from .. import config from ..database import Generation as DBGeneration, ProfileSample as DBProfileSample, VoiceProfile as DBVoiceProfile from ..models import ( EffectConfig, ProfileSampleResponse, VoiceProfileCreate, VoiceProfileResponse, ) from ..utils.audio import save_audio, validate_and_load_reference_audio from ..utils.cache import _get_cache_dir, clear_profile_cache from ..utils.images import process_avatar, validate_image logger = logging.getLogger(__name__) CLONING_ENGINES = {"qwen", "luxtts", "chatterbox", "chatterbox_turbo", "tada"} def _profile_to_response( profile: DBVoiceProfile, generation_count: int = 0, sample_count: int = 0, ) -> VoiceProfileResponse: """Convert a DB profile to a VoiceProfileResponse, deserializing effects_chain.""" effects_chain = None if profile.effects_chain: try: raw = _json.loads(profile.effects_chain) effects_chain = [EffectConfig(**e) for e in raw] except Exception as e: import logging logging.warning(f"Failed to parse effects_chain for profile {profile.id}: {e}") return VoiceProfileResponse( id=profile.id, name=profile.name, description=profile.description, language=profile.language, avatar_path=profile.avatar_path, effects_chain=effects_chain, voice_type=getattr(profile, "voice_type", None) or "cloned", preset_engine=getattr(profile, "preset_engine", None), preset_voice_id=getattr(profile, "preset_voice_id", None), design_prompt=getattr(profile, "design_prompt", None), default_engine=getattr(profile, "default_engine", None), generation_count=generation_count, sample_count=sample_count, created_at=profile.created_at, updated_at=profile.updated_at, ) def _get_preset_voice_ids(engine: str) -> set[str]: if engine == "kokoro": from ..backends.kokoro_backend import KOKORO_VOICES return {voice_id for voice_id, _name, _gender, _lang in KOKORO_VOICES} if engine == "qwen_custom_voice": from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES return {voice_id for voice_id, _name, _gender, _lang, _desc in QWEN_CUSTOM_VOICES} return set() def _validate_profile_fields( *, voice_type: str, preset_engine: str | None, preset_voice_id: str | None, design_prompt: str | None, default_engine: str | None, ) -> str | None: if voice_type == "preset": if not preset_engine or not preset_voice_id: return "Preset profiles require both preset_engine and preset_voice_id" if default_engine and default_engine != preset_engine: return "Preset profiles must use their preset_engine as default_engine" available_voice_ids = _get_preset_voice_ids(preset_engine) if available_voice_ids and preset_voice_id not in available_voice_ids: return f"Preset voice '{preset_voice_id}' is not valid for engine '{preset_engine}'" return None if voice_type == "designed": if not design_prompt or not design_prompt.strip(): return "Designed profiles require a design_prompt" if preset_engine or preset_voice_id: return "Designed profiles cannot set preset_engine or preset_voice_id" return None if preset_engine or preset_voice_id: return "Cloned profiles cannot set preset_engine or preset_voice_id" if design_prompt: return "Cloned profiles cannot set design_prompt" if default_engine and default_engine not in CLONING_ENGINES: return f"Cloned profiles cannot use default engine '{default_engine}'" return None def validate_profile_engine(profile, engine: str) -> None: voice_type = getattr(profile, "voice_type", None) or "cloned" if voice_type == "preset": preset_engine = getattr(profile, "preset_engine", None) preset_voice_id = getattr(profile, "preset_voice_id", None) if not preset_engine or not preset_voice_id: raise ValueError(f"Preset profile {profile.id} is missing preset engine metadata") if preset_engine != engine: raise ValueError( f"Preset profile {profile.id} only supports engine '{preset_engine}', not '{engine}'" ) return if voice_type == "designed": design_prompt = getattr(profile, "design_prompt", None) if not design_prompt or not design_prompt.strip(): raise ValueError(f"Designed profile {profile.id} is missing design_prompt") return if engine not in CLONING_ENGINES: raise ValueError(f"Engine '{engine}' does not support cloned voice profiles") async def create_profile( data: VoiceProfileCreate, db: Session, ) -> VoiceProfileResponse: """ Create a new voice profile. Args: data: Profile creation data db: Database session Returns: Created profile Raises: ValueError: If a profile with the same name already exists """ existing_profile = db.query(DBVoiceProfile).filter_by(name=data.name).first() if existing_profile: raise ValueError(f"A profile with the name '{data.name}' already exists. Please choose a different name.") # Auto-set default_engine for preset profiles default_engine = data.default_engine voice_type = data.voice_type or "cloned" if voice_type == "preset" and data.preset_engine and not default_engine: default_engine = data.preset_engine validation_error = _validate_profile_fields( voice_type=voice_type, preset_engine=data.preset_engine, preset_voice_id=data.preset_voice_id, design_prompt=data.design_prompt, default_engine=default_engine, ) if validation_error: raise ValueError(validation_error) db_profile = DBVoiceProfile( id=str(uuid.uuid4()), name=data.name, description=data.description, language=data.language, voice_type=voice_type, preset_engine=data.preset_engine, preset_voice_id=data.preset_voice_id, design_prompt=data.design_prompt, default_engine=default_engine, created_at=datetime.utcnow(), updated_at=datetime.utcnow(), ) db.add(db_profile) db.commit() db.refresh(db_profile) profile_dir = config.get_profiles_dir() / db_profile.id profile_dir.mkdir(parents=True, exist_ok=True) return _profile_to_response(db_profile) async def add_profile_sample( profile_id: str, audio_path: str, reference_text: str, db: Session, ) -> ProfileSampleResponse: """ Add a sample to a voice profile. Args: profile_id: Profile ID audio_path: Path to temporary audio file reference_text: Transcript of audio db: Database session Returns: Created sample """ import asyncio profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first() if not profile: raise ValueError(f"Profile {profile_id} not found") # Validate and load audio in a single pass, off the event loop is_valid, error_msg, audio, sr = await asyncio.to_thread( validate_and_load_reference_audio, audio_path ) if not is_valid: raise ValueError(f"Invalid reference audio: {error_msg}") sample_id = str(uuid.uuid4()) profile_dir = config.get_profiles_dir() / profile_id profile_dir.mkdir(parents=True, exist_ok=True) dest_path = profile_dir / f"{sample_id}.wav" await asyncio.to_thread(save_audio, audio, str(dest_path), sr) db_sample = DBProfileSample( id=sample_id, profile_id=profile_id, audio_path=config.to_storage_path(dest_path), reference_text=reference_text, ) db.add(db_sample) profile.updated_at = datetime.utcnow() db.commit() db.refresh(db_sample) # Invalidate combined audio cache for this profile # Since a new sample was added, any cached combined audio is now stale clear_profile_cache(profile_id) return ProfileSampleResponse.model_validate(db_sample) async def get_profile( profile_id: str, db: Session, ) -> VoiceProfileResponse | None: """ Get a voice profile by ID. Args: profile_id: Profile ID db: Database session Returns: Profile or None if not found """ profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first() if not profile: return None return _profile_to_response(profile) async def get_profile_samples( profile_id: str, db: Session, ) -> list[ProfileSampleResponse]: """ Get all samples for a profile. Args: profile_id: Profile ID db: Database session Returns: List of samples """ samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all() return [ProfileSampleResponse.model_validate(s) for s in samples] async def list_profiles(db: Session) -> list[VoiceProfileResponse]: """ List all voice profiles with generation and sample counts. Args: db: Database session Returns: List of profiles """ profiles = db.query(DBVoiceProfile).order_by(DBVoiceProfile.created_at.desc()).all() if not profiles: return [] # Batch-fetch generation counts gen_counts_rows = ( db.query(DBGeneration.profile_id, func.count(DBGeneration.id)).group_by(DBGeneration.profile_id).all() ) gen_counts = {row[0]: row[1] for row in gen_counts_rows} # Batch-fetch sample counts sample_counts_rows = ( db.query(DBProfileSample.profile_id, func.count(DBProfileSample.id)).group_by(DBProfileSample.profile_id).all() ) sample_counts = {row[0]: row[1] for row in sample_counts_rows} return [ _profile_to_response( p, generation_count=gen_counts.get(p.id, 0), sample_count=sample_counts.get(p.id, 0), ) for p in profiles ] async def update_profile( profile_id: str, data: VoiceProfileCreate, db: Session, ) -> VoiceProfileResponse | None: """ Update a voice profile. Args: profile_id: Profile ID data: Updated profile data db: Database session Returns: Updated profile or None if not found Raises: ValueError: If a profile with the same name already exists (different profile) """ profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first() if not profile: return None if profile.name != data.name: existing_profile = db.query(DBVoiceProfile).filter_by(name=data.name).first() if existing_profile: raise ValueError(f"A profile with the name '{data.name}' already exists. Please choose a different name.") voice_type = getattr(profile, "voice_type", None) or "cloned" preset_engine = getattr(profile, "preset_engine", None) preset_voice_id = getattr(profile, "preset_voice_id", None) design_prompt = getattr(profile, "design_prompt", None) default_engine = data.default_engine if data.default_engine is not None else getattr(profile, "default_engine", None) validation_error = _validate_profile_fields( voice_type=voice_type, preset_engine=preset_engine, preset_voice_id=preset_voice_id, design_prompt=design_prompt, default_engine=default_engine, ) if validation_error: raise ValueError(validation_error) profile.name = data.name profile.description = data.description profile.language = data.language if data.default_engine is not None: profile.default_engine = data.default_engine or None # empty string → NULL profile.updated_at = datetime.utcnow() db.commit() db.refresh(profile) return _profile_to_response(profile) async def delete_profile( profile_id: str, db: Session, ) -> bool: """ Delete a voice profile and all associated data. Args: profile_id: Profile ID db: Database session Returns: True if deleted, False if not found """ profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first() if not profile: return False db.query(DBProfileSample).filter_by(profile_id=profile_id).delete() db.delete(profile) db.commit() profile_dir = config.get_profiles_dir() / profile_id if profile_dir.exists(): shutil.rmtree(profile_dir) # Clean up combined audio cache files for this profile clear_profile_cache(profile_id) return True async def delete_profile_sample( sample_id: str, db: Session, ) -> bool: """ Delete a profile sample. Args: sample_id: Sample ID db: Database session Returns: True if deleted, False if not found """ sample = db.query(DBProfileSample).filter_by(id=sample_id).first() if not sample: return False # Store profile_id before deleting profile_id = sample.profile_id audio_path = config.resolve_storage_path(sample.audio_path) if audio_path is not None and audio_path.exists(): audio_path.unlink() db.delete(sample) db.commit() # Invalidate combined audio cache for this profile # Since the sample set changed, any cached combined audio is now stale clear_profile_cache(profile_id) return True async def update_profile_sample( sample_id: str, reference_text: str, db: Session, ) -> ProfileSampleResponse | None: """ Update a profile sample's reference text. Args: sample_id: Sample ID reference_text: Updated reference text db: Database session Returns: Updated sample or None if not found """ sample = db.query(DBProfileSample).filter_by(id=sample_id).first() if not sample: return None # Store profile_id before updating profile_id = sample.profile_id sample.reference_text = reference_text db.commit() db.refresh(sample) # Invalidate combined audio cache for this profile # Since the reference text changed, cache keys and combined text are now stale clear_profile_cache(profile_id) return ProfileSampleResponse.model_validate(sample) async def create_voice_prompt_for_profile( profile_id: str, db: Session, use_cache: bool = True, engine: str = "qwen", ) -> dict: """ Create a voice prompt from a profile. For cloned profiles: combines all audio samples into a voice prompt. For preset profiles: returns the engine-specific preset voice reference. For designed profiles: returns the text design prompt (future). Args: profile_id: Profile ID db: Database session use_cache: Whether to use cached prompts engine: TTS engine to create prompt for Returns: Voice prompt dictionary """ from ..backends import get_tts_backend_for_engine profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first() if not profile: raise ValueError(f"Profile not found: {profile_id}") voice_type = getattr(profile, "voice_type", None) or "cloned" validate_profile_engine(profile, engine) # ── Preset profiles: return engine-specific voice reference ── if voice_type == "preset": if not profile.preset_engine or not profile.preset_voice_id: raise ValueError(f"Preset profile {profile_id} is missing preset engine metadata") if profile.preset_engine != engine: raise ValueError( f"Preset profile {profile_id} only supports engine '{profile.preset_engine}', not '{engine}'" ) return { "voice_type": "preset", "preset_engine": profile.preset_engine, "preset_voice_id": profile.preset_voice_id, } # ── Designed profiles: return text description (future) ── if voice_type == "designed": if not profile.design_prompt or not profile.design_prompt.strip(): raise ValueError(f"Designed profile {profile_id} is missing design_prompt") return { "voice_type": "designed", "design_prompt": profile.design_prompt, } if engine not in CLONING_ENGINES: raise ValueError(f"Engine '{engine}' does not support cloned voice profiles") # ── Cloned profiles: create from audio samples ── samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all() if not samples: raise ValueError(f"No samples found for profile {profile_id}") tts_model = get_tts_backend_for_engine(engine) if len(samples) == 1: sample = samples[0] sample_audio_path = config.resolve_storage_path(sample.audio_path) if sample_audio_path is None: raise ValueError(f"Sample audio not found for profile {profile_id}") voice_prompt, _ = await tts_model.create_voice_prompt( str(sample_audio_path), sample.reference_text, use_cache=use_cache, ) return voice_prompt audio_paths = [] for sample in samples: sample_audio_path = config.resolve_storage_path(sample.audio_path) if sample_audio_path is None: raise ValueError(f"Sample audio not found for profile {profile_id}") audio_paths.append(str(sample_audio_path)) reference_texts = [s.reference_text for s in samples] combined_audio, combined_text = await tts_model.combine_voice_prompts( audio_paths, reference_texts, ) # Save combined audio to cache directory (persistent) # Create a hash of sample IDs to identify this specific combination import hashlib sample_ids_str = "-".join(sorted([s.id for s in samples])) combination_hash = hashlib.md5(sample_ids_str.encode()).hexdigest()[:12] cache_dir = _get_cache_dir() cache_dir.mkdir(parents=True, exist_ok=True) combined_path = cache_dir / f"combined_{profile_id}_{combination_hash}.wav" save_audio(combined_audio, str(combined_path), 24000) voice_prompt, _ = await tts_model.create_voice_prompt( str(combined_path), combined_text, use_cache=use_cache, ) return voice_prompt async def upload_avatar( profile_id: str, image_path: str, db: Session, ) -> VoiceProfileResponse: """ Upload and process avatar image for a profile. Args: profile_id: Profile ID image_path: Path to uploaded image file db: Database session Returns: Updated profile """ profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first() if not profile: raise ValueError(f"Profile {profile_id} not found") is_valid, error_msg = validate_image(image_path) if not is_valid: raise ValueError(error_msg) if profile.avatar_path: old_avatar = config.resolve_storage_path(profile.avatar_path) if old_avatar is not None and old_avatar.exists(): old_avatar.unlink() # Determine file extension from uploaded file from PIL import Image with Image.open(image_path) as img: # Normalize JPEG variants (MPO is multi-picture format from some cameras) img_format = img.format if img_format in ("MPO", "JPG"): img_format = "JPEG" ext_map = {"PNG": ".png", "JPEG": ".jpg", "WEBP": ".webp"} ext = ext_map.get(img_format, ".png") profile_dir = config.get_profiles_dir() / profile_id profile_dir.mkdir(parents=True, exist_ok=True) output_path = profile_dir / f"avatar{ext}" process_avatar(image_path, str(output_path)) profile.avatar_path = config.to_storage_path(output_path) profile.updated_at = datetime.utcnow() db.commit() db.refresh(profile) return _profile_to_response(profile) async def delete_avatar( profile_id: str, db: Session, ) -> bool: """ Delete avatar image for a profile. Args: profile_id: Profile ID db: Database session Returns: True if deleted, False if not found or no avatar """ profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first() if not profile or not profile.avatar_path: return False avatar_path = config.resolve_storage_path(profile.avatar_path) if avatar_path is not None and avatar_path.exists(): avatar_path.unlink() profile.avatar_path = None profile.updated_at = datetime.utcnow() db.commit() return True