"""
Story management module.

CRUD operations for stories and their timeline items (clips), plus
export of a story as a single mixed audio file.
"""

import logging
import tempfile
import uuid
from datetime import datetime
from pathlib import Path
from typing import List, Optional

import numpy as np
from sqlalchemy import func
from sqlalchemy.orm import Session

from .. import config
from ..models import (
    StoryCreate,
    StoryResponse,
    StoryDetailResponse,
    StoryItemDetail,
    StoryItemCreate,
    StoryItemBatchUpdate,
    StoryItemMove,
    StoryItemTrim,
    StoryItemSplit,
    StoryItemVersionUpdate,
)
from ..database import (
    Story as DBStory,
    StoryItem as DBStoryItem,
    Generation as DBGeneration,
    GenerationVersion as DBGenerationVersion,
    VoiceProfile as DBVoiceProfile,
)
from .history import _get_versions_for_generation
from ..utils.audio import load_audio, save_audio

logger = logging.getLogger(__name__)

# Name reported for an item whose voice profile row cannot be found.
_UNKNOWN_PROFILE_NAME = "Unknown"

# Silence inserted between consecutive clips when positions are computed here.
_DEFAULT_GAP_MS = 200

# Sample rate requested from load_audio when exporting.
_DEFAULT_SAMPLE_RATE = 24000


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------


def _trim_ms(item: DBStoryItem, attr: str) -> int:
    """Return a trim attribute of ``item`` as an int, treating missing/NULL as 0."""
    return getattr(item, attr, 0) or 0


def _duration_ms(generation: DBGeneration) -> int:
    """Return the generation's duration in whole milliseconds (NULL counts as 0)."""
    return int((generation.duration or 0) * 1000)


def _count_story_items(story_id: str, db: Session) -> int:
    """Return the number of items that belong to ``story_id``."""
    return db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == story_id).scalar()


def _story_response(story: DBStory, item_count: int) -> StoryResponse:
    """Build a StoryResponse for ``story`` with ``item_count`` filled in."""
    response = StoryResponse.model_validate(story)
    response.item_count = item_count
    return response


def _get_story_item(
    story_id: str,
    item_id: str,
    db: Session,
    *,
    lock: bool = False,
) -> Optional[DBStoryItem]:
    """Fetch one item of a story, optionally with ``SELECT ... FOR UPDATE``."""
    query = db.query(DBStoryItem).filter_by(id=item_id, story_id=story_id)
    if lock:
        query = query.with_for_update()
    return query.first()


def _get_profile_name(generation: DBGeneration, db: Session) -> str:
    """Return the voice profile name for ``generation``, or "Unknown" if missing."""
    profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
    return profile.name if profile else _UNKNOWN_PROFILE_NAME


def _touch_story(story_id: str, db: Session) -> None:
    """Set the story's ``updated_at`` to now, if the story exists (no commit)."""
    story = db.query(DBStory).filter_by(id=story_id).first()
    if story:
        story.updated_at = datetime.utcnow()


def _build_item_detail(
    item: DBStoryItem,
    generation: DBGeneration,
    profile_name: str,
    db: Session,
) -> StoryItemDetail:
    """Build a StoryItemDetail with version info from a story item and its generation."""
    versions, active_version_id = _get_versions_for_generation(generation.id, db)

    # Resolve the audio path: a pinned version overrides the generation default.
    pinned_version_id = getattr(item, "version_id", None)
    audio_path = generation.audio_path
    if pinned_version_id and versions:
        audio_path = next(
            (v.audio_path for v in versions if v.id == pinned_version_id),
            audio_path,
        )

    return StoryItemDetail(
        id=item.id,
        story_id=item.story_id,
        generation_id=item.generation_id,
        version_id=pinned_version_id,
        start_time_ms=item.start_time_ms,
        track=item.track,
        trim_start_ms=_trim_ms(item, "trim_start_ms"),
        trim_end_ms=_trim_ms(item, "trim_end_ms"),
        created_at=item.created_at,
        profile_id=generation.profile_id,
        profile_name=profile_name,
        text=generation.text,
        language=generation.language,
        audio_path=audio_path,
        duration=generation.duration,
        seed=generation.seed,
        instruct=generation.instruct,
        generation_created_at=generation.created_at,
        versions=versions,
        active_version_id=active_version_id,
    )


# ---------------------------------------------------------------------------
# Story CRUD
# ---------------------------------------------------------------------------


async def create_story(
    data: StoryCreate,
    db: Session,
) -> StoryResponse:
    """
    Create a new story.

    Args:
        data: Story creation data
        db: Database session

    Returns:
        Created story
    """
    # One timestamp so created_at and updated_at are exactly equal on creation.
    now = datetime.utcnow()
    db_story = DBStory(
        id=str(uuid.uuid4()),
        name=data.name,
        description=data.description,
        created_at=now,
        updated_at=now,
    )
    db.add(db_story)
    db.commit()
    db.refresh(db_story)

    # The id was generated above, so no item can reference this story yet.
    return _story_response(db_story, 0)


async def list_stories(
    db: Session,
) -> List[StoryResponse]:
    """
    List all stories.

    Args:
        db: Database session

    Returns:
        List of stories with item counts, most recently updated first
    """
    stories = db.query(DBStory).order_by(DBStory.updated_at.desc()).all()

    # One grouped query instead of one COUNT per story.
    count_rows = (
        db.query(DBStoryItem.story_id, func.count(DBStoryItem.id))
        .group_by(DBStoryItem.story_id)
        .all()
    )
    counts = {story_id: count for story_id, count in count_rows}

    return [_story_response(story, counts.get(story.id, 0)) for story in stories]


async def get_story(
    story_id: str,
    db: Session,
) -> Optional[StoryDetailResponse]:
    """
    Get a story with all its items.

    Args:
        story_id: Story ID
        db: Database session

    Returns:
        Story with items (ordered by start time) or None if not found
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    # NOTE(review): these are inner joins, so an item whose generation or
    # voice profile row is missing is omitted from the result.
    rows = (
        db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
        .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
        .join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
        .filter(DBStoryItem.story_id == story_id)
        .order_by(DBStoryItem.start_time_ms)
        .all()
    )

    response = StoryDetailResponse.model_validate(story)
    response.items = [
        _build_item_detail(item, generation, profile_name, db)
        for item, generation, profile_name in rows
    ]
    return response


async def update_story(
    story_id: str,
    data: StoryCreate,
    db: Session,
) -> Optional[StoryResponse]:
    """
    Update a story.

    Args:
        story_id: Story ID
        data: Update data
        db: Database session

    Returns:
        Updated story or None if not found
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    story.name = data.name
    story.description = data.description
    story.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(story)

    return _story_response(story, _count_story_items(story.id, db))


async def delete_story(
    story_id: str,
    db: Session,
) -> bool:
    """
    Delete a story and all its items.

    Args:
        story_id: Story ID
        db: Database session

    Returns:
        True if deleted, False if not found
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return False

    # Delete the items first, then the story itself, in one transaction.
    db.query(DBStoryItem).filter_by(story_id=story_id).delete()
    db.delete(story)
    db.commit()

    return True


# ---------------------------------------------------------------------------
# Story items
# ---------------------------------------------------------------------------


async def add_item_to_story(
    story_id: str,
    data: StoryItemCreate,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Add a generation to a story.

    If the generation is already part of the story, the existing item is
    returned unchanged.

    Args:
        story_id: Story ID
        data: Item creation data
        db: Database session

    Returns:
        Created (or existing) item detail, or None if story/generation not found
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    generation = db.query(DBGeneration).filter_by(id=data.generation_id).first()
    if not generation:
        return None

    existing = db.query(DBStoryItem).filter_by(story_id=story_id, generation_id=data.generation_id).first()
    if existing:
        return _build_item_detail(existing, generation, _get_profile_name(generation, db), db)

    track = data.track if data.track is not None else 0

    if data.start_time_ms is not None:
        start_time_ms = data.start_time_ms
    else:
        # Append after the latest-ending clip on the same track.
        track_items = (
            db.query(DBStoryItem, DBGeneration)
            .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
            .filter(
                DBStoryItem.story_id == story_id,
                DBStoryItem.track == track,
            )
            .all()
        )
        if not track_items:
            start_time_ms = 0
        else:
            # End times use the full generation duration (trims are ignored
            # here); the floor of 0 covers clips placed at negative times.
            max_end_time_ms = max(
                [0] + [other.start_time_ms + _duration_ms(gen) for other, gen in track_items]
            )
            start_time_ms = max_end_time_ms + _DEFAULT_GAP_MS

    item = DBStoryItem(
        id=str(uuid.uuid4()),
        story_id=story_id,
        generation_id=data.generation_id,
        start_time_ms=start_time_ms,
        track=track,
        created_at=datetime.utcnow(),
    )
    db.add(item)

    story.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(item)

    return _build_item_detail(item, generation, _get_profile_name(generation, db), db)


async def move_story_item(
    story_id: str,
    item_id: str,
    data: StoryItemMove,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Move a story item (update position and/or track).

    Args:
        story_id: Story ID
        item_id: Story item ID
        data: New position and track data
        db: Database session

    Returns:
        Updated item detail or None if not found
    """
    item = _get_story_item(story_id, item_id, db)
    if not item:
        return None

    generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
    if not generation:
        return None

    item.start_time_ms = data.start_time_ms
    item.track = data.track

    _touch_story(story_id, db)

    db.commit()
    db.refresh(item)

    return _build_item_detail(item, generation, _get_profile_name(generation, db), db)


async def remove_item_from_story(
    story_id: str,
    item_id: str,
    db: Session,
) -> bool:
    """
    Remove a story item from a story.

    Args:
        story_id: Story ID
        item_id: Story item ID to remove
        db: Database session

    Returns:
        True if removed, False if not found
    """
    item = _get_story_item(story_id, item_id, db)
    if not item:
        return False

    db.delete(item)
    _touch_story(story_id, db)
    db.commit()

    return True


async def trim_story_item(
    story_id: str,
    item_id: str,
    data: StoryItemTrim,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Trim a story item (update trim_start_ms and trim_end_ms).

    Args:
        story_id: Story ID
        item_id: Story item ID
        data: Trim data (trim_start_ms, trim_end_ms)
        db: Database session

    Returns:
        Updated item detail, or None if the item/generation is not found or
        the trim is invalid (negative, or leaving zero/negative duration)
    """
    item = _get_story_item(story_id, item_id, db)
    if not item:
        return None

    generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
    if not generation:
        return None

    # A negative trim would extend the clip beyond its source audio.
    if data.trim_start_ms < 0 or data.trim_end_ms < 0:
        return None

    # The combined trim must leave a strictly positive duration.
    if data.trim_start_ms + data.trim_end_ms >= _duration_ms(generation):
        return None

    item.trim_start_ms = data.trim_start_ms
    item.trim_end_ms = data.trim_end_ms

    _touch_story(story_id, db)

    db.commit()
    db.refresh(item)

    return _build_item_detail(item, generation, _get_profile_name(generation, db), db)


async def split_story_item(
    story_id: str,
    item_id: str,
    data: StoryItemSplit,
    db: Session,
) -> Optional[List[StoryItemDetail]]:
    """
    Split a story item at a given time, creating two clips.

    Args:
        story_id: Story ID
        item_id: Story item ID to split
        data: Split data (split_time_ms - time within clip to split at)
        db: Database session

    Returns:
        List of two updated item details (original and new) or None if not found/invalid
    """
    # Row lock to prevent concurrent splits on the same clip
    # (e.g. from rapid double-clicks racing each other).
    item = _get_story_item(story_id, item_id, db, lock=True)
    if not item:
        return None

    generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
    if not generation:
        return None

    current_trim_start = _trim_ms(item, "trim_start_ms")
    current_trim_end = _trim_ms(item, "trim_end_ms")
    original_duration_ms = _duration_ms(generation)
    effective_duration_ms = original_duration_ms - current_trim_start - current_trim_end

    # The split point must fall strictly inside the visible part of the clip.
    if data.split_time_ms <= 0 or data.split_time_ms >= effective_duration_ms:
        return None

    # Position of the split measured in the untrimmed source audio.
    absolute_split_ms = current_trim_start + data.split_time_ms

    # Original clip keeps the first half: trim everything after the split.
    item.trim_end_ms = original_duration_ms - absolute_split_ms

    # New clip is the second half: same generation, trimmed from the start.
    new_item = DBStoryItem(
        id=str(uuid.uuid4()),
        story_id=story_id,
        generation_id=item.generation_id,
        version_id=getattr(item, "version_id", None),  # Preserve pinned version
        start_time_ms=item.start_time_ms + data.split_time_ms,
        track=item.track,
        trim_start_ms=absolute_split_ms,
        trim_end_ms=current_trim_end,
        created_at=datetime.utcnow(),
    )
    db.add(new_item)

    _touch_story(story_id, db)

    db.commit()
    db.refresh(item)
    db.refresh(new_item)

    profile_name = _get_profile_name(generation, db)
    return [
        _build_item_detail(item, generation, profile_name, db),
        _build_item_detail(new_item, generation, profile_name, db),
    ]


async def duplicate_story_item(
    story_id: str,
    item_id: str,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Duplicate a story item, creating a copy with all properties.

    The copy is placed on the same track, 200ms after the end of the
    original clip (using the original's trimmed duration).

    Args:
        story_id: Story ID
        item_id: Story item ID to duplicate
        db: Database session

    Returns:
        New item detail or None if not found
    """
    original_item = _get_story_item(story_id, item_id, db)
    if not original_item:
        return None

    generation = db.query(DBGeneration).filter_by(id=original_item.generation_id).first()
    if not generation:
        return None

    current_trim_start = _trim_ms(original_item, "trim_start_ms")
    current_trim_end = _trim_ms(original_item, "trim_end_ms")
    effective_duration_ms = _duration_ms(generation) - current_trim_start - current_trim_end

    new_item = DBStoryItem(
        id=str(uuid.uuid4()),
        story_id=story_id,
        generation_id=original_item.generation_id,
        version_id=getattr(original_item, "version_id", None),  # Preserve pinned version
        start_time_ms=original_item.start_time_ms + effective_duration_ms + _DEFAULT_GAP_MS,
        track=original_item.track,
        trim_start_ms=current_trim_start,
        trim_end_ms=current_trim_end,
        created_at=datetime.utcnow(),
    )
    db.add(new_item)

    _touch_story(story_id, db)

    db.commit()
    db.refresh(new_item)

    return _build_item_detail(new_item, generation, _get_profile_name(generation, db), db)


async def update_story_item_times(
    story_id: str,
    data: StoryItemBatchUpdate,
    db: Session,
) -> bool:
    """
    Update story item timecodes.

    The batch is all-or-nothing: if any update references a generation that
    is not part of the story, nothing is modified.

    Args:
        story_id: Story ID
        data: Batch update data with timecodes
        db: Database session

    Returns:
        True if updated, False if story not found or invalid
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return False

    items = db.query(DBStoryItem).filter_by(story_id=story_id).all()
    # NOTE(review): keyed by generation_id, so when several items share a
    # generation (created by split/duplicate) only one of them is addressable.
    item_map = {item.generation_id: item for item in items}

    # Validate the whole batch before mutating anything, so a rejected batch
    # does not leave half-applied changes pending in the session.
    if any(update.generation_id not in item_map for update in data.updates):
        return False

    for update in data.updates:
        item_map[update.generation_id].start_time_ms = update.start_time_ms

    story.updated_at = datetime.utcnow()
    db.commit()

    return True


async def reorder_story_items(
    story_id: str,
    generation_ids: List[str],
    db: Session,
    gap_ms: int = _DEFAULT_GAP_MS,
) -> Optional[List[StoryItemDetail]]:
    """
    Reorder story items and recalculate timecodes.

    Items are laid out back to back in the given order, each starting
    ``gap_ms`` after the previous one ends. The full generation duration is
    used for spacing (trims are not taken into account).

    Args:
        story_id: Story ID
        generation_ids: List of generation IDs in the desired order
        db: Database session
        gap_ms: Gap in milliseconds between items (default 200ms)

    Returns:
        Updated list of story items with new timecodes, or None if invalid
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    rows = (
        db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
        .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
        .join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
        .filter(DBStoryItem.story_id == story_id)
        .all()
    )

    # NOTE(review): keyed by generation_id, so when several items share a
    # generation (created by split/duplicate) only one of them is repositioned.
    item_map = {item.generation_id: (item, gen, profile_name) for item, gen, profile_name in rows}

    # The requested order must mention exactly the generations in the story.
    if set(generation_ids) != set(item_map):
        return None

    current_time_ms = 0
    updated_items = []
    for gen_id in generation_ids:
        item, generation, profile_name = item_map[gen_id]
        item.start_time_ms = current_time_ms
        # Next clip starts after this one ends plus the gap.
        current_time_ms += _duration_ms(generation) + gap_ms
        updated_items.append(_build_item_detail(item, generation, profile_name, db))

    story.updated_at = datetime.utcnow()
    db.commit()

    return updated_items


async def set_story_item_version(
    story_id: str,
    item_id: str,
    data: StoryItemVersionUpdate,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Pin a story item to a specific generation version.

    Args:
        story_id: Story ID
        item_id: Story item ID
        data: Version update data (version_id or null for default)
        db: Database session

    Returns:
        Updated item detail, or None if the item/generation is not found or
        the version does not belong to the item's generation
    """
    item = _get_story_item(story_id, item_id, db)
    if not item:
        return None

    generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
    if not generation:
        return None

    # A provided version must belong to this item's generation.
    if data.version_id:
        version = (
            db.query(DBGenerationVersion)
            .filter_by(
                id=data.version_id,
                generation_id=item.generation_id,
            )
            .first()
        )
        if not version:
            return None

    item.version_id = data.version_id

    _touch_story(story_id, db)

    db.commit()
    db.refresh(item)

    return _build_item_detail(item, generation, _get_profile_name(generation, db), db)


# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------


async def export_story_audio(
    story_id: str,
    db: Session,
) -> Optional[bytes]:
    """
    Export story as single mixed audio file with timecode-based mixing.

    Each clip is trimmed, placed at its start time and summed into one mono
    buffer; the result is scaled down only if its peak exceeds 1.0. Clips
    whose audio file is missing or cannot be loaded are skipped.

    Args:
        story_id: Story ID
        db: Database session

    Returns:
        WAV file bytes, or None if the story is not found, has no items, or
        no audio could be loaded for any of its items
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    rows = (
        db.query(DBStoryItem, DBGeneration)
        .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
        .filter(DBStoryItem.story_id == story_id)
        .order_by(DBStoryItem.start_time_ms)
        .all()
    )
    if not rows:
        return None

    sample_rate = _DEFAULT_SAMPLE_RATE
    clips = []  # (trimmed_audio, start_time_ms, effective_duration_ms)

    for item, generation in rows:
        # Use the pinned version's audio if set, otherwise the generation default.
        resolved_audio_path = generation.audio_path
        pinned_version_id = getattr(item, "version_id", None)
        if pinned_version_id:
            version = db.query(DBGenerationVersion).filter_by(id=pinned_version_id).first()
            if version:
                resolved_audio_path = version.audio_path

        audio_path = config.resolve_storage_path(resolved_audio_path)
        if audio_path is None or not audio_path.exists():
            continue

        try:
            audio, sr = load_audio(str(audio_path), sample_rate=sample_rate)
            sample_rate = sr  # Follow the rate load_audio actually returned

            trim_start_ms = _trim_ms(item, "trim_start_ms")
            trim_end_ms = _trim_ms(item, "trim_end_ms")
            effective_duration_ms = _duration_ms(generation) - trim_start_ms - trim_end_ms

            trim_start_sample = int((trim_start_ms / 1000.0) * sample_rate)
            trim_end_sample = int((trim_end_ms / 1000.0) * sample_rate)

            # A slice ending at "-0" would be empty, so only use the negative
            # end index when there is something to cut off the tail.
            if trim_end_sample > 0:
                trimmed_audio = audio[trim_start_sample:-trim_end_sample]
            else:
                trimmed_audio = audio[trim_start_sample:]

            clips.append((trimmed_audio, item.start_time_ms, effective_duration_ms))
        except Exception:
            # Best effort: one unreadable clip must not fail the whole export.
            logger.warning(
                "Skipping story item %s during export: could not load %s",
                item.id,
                audio_path,
                exc_info=True,
            )
            continue

    if not clips:
        return None

    # Timeline length according to the stored timecodes and durations.
    max_end_time_ms = max(start_ms + duration_ms for _, start_ms, duration_ms in clips)
    total_samples = int((max_end_time_ms / 1000.0) * sample_rate)

    placed = []  # (start_sample, audio)
    for audio, start_ms, _ in clips:
        start_sample = int((start_ms / 1000.0) * sample_rate)
        if start_sample < 0:
            # Clip starts before time zero: drop the part that precedes it.
            audio = audio[-start_sample:]
            start_sample = 0
        placed.append((start_sample, audio))
        # Grow the buffer if the decoded audio runs past the stored duration
        # (e.g. a pinned version longer than the generation's duration).
        total_samples = max(total_samples, start_sample + len(audio))

    if total_samples <= 0:
        return None

    final_audio = np.zeros(total_samples, dtype=np.float32)

    # Overlapping clips are summed.
    for start_sample, audio in placed:
        final_audio[start_sample : start_sample + len(audio)] += audio

    # Scale down only when the mix would clip.
    max_val = np.abs(final_audio).max()
    if max_val > 1.0:
        final_audio = final_audio / max_val

    # save_audio writes to a path, so round-trip through a temporary file.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name

    try:
        save_audio(final_audio, tmp_path, sample_rate)
        with open(tmp_path, "rb") as f:
            return f.read()
    finally:
        Path(tmp_path).unlink(missing_ok=True)