Initial commit
This commit is contained in:
925
backend/services/stories.py
Normal file
925
backend/services/stories.py
Normal file
@@ -0,0 +1,925 @@
|
||||
"""
|
||||
Story management module.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from .. import config
|
||||
from ..models import (
|
||||
StoryCreate,
|
||||
StoryResponse,
|
||||
StoryDetailResponse,
|
||||
StoryItemDetail,
|
||||
StoryItemCreate,
|
||||
StoryItemBatchUpdate,
|
||||
StoryItemMove,
|
||||
StoryItemTrim,
|
||||
StoryItemSplit,
|
||||
StoryItemVersionUpdate,
|
||||
)
|
||||
from ..database import (
|
||||
Story as DBStory,
|
||||
StoryItem as DBStoryItem,
|
||||
Generation as DBGeneration,
|
||||
VoiceProfile as DBVoiceProfile,
|
||||
)
|
||||
from .history import _get_versions_for_generation
|
||||
from ..utils.audio import load_audio, save_audio
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _build_item_detail(
|
||||
item: DBStoryItem,
|
||||
generation: DBGeneration,
|
||||
profile_name: str,
|
||||
db: Session,
|
||||
) -> StoryItemDetail:
|
||||
"""Build a StoryItemDetail with version info from a story item and its generation."""
|
||||
versions, active_version_id = _get_versions_for_generation(generation.id, db)
|
||||
|
||||
# Resolve the audio path: if version_id is set, use that version's audio
|
||||
audio_path = generation.audio_path
|
||||
if item.version_id and versions:
|
||||
for v in versions:
|
||||
if v.id == item.version_id:
|
||||
audio_path = v.audio_path
|
||||
break
|
||||
|
||||
return StoryItemDetail(
|
||||
id=item.id,
|
||||
story_id=item.story_id,
|
||||
generation_id=item.generation_id,
|
||||
version_id=getattr(item, "version_id", None),
|
||||
start_time_ms=item.start_time_ms,
|
||||
track=item.track,
|
||||
trim_start_ms=getattr(item, "trim_start_ms", 0),
|
||||
trim_end_ms=getattr(item, "trim_end_ms", 0),
|
||||
created_at=item.created_at,
|
||||
profile_id=generation.profile_id,
|
||||
profile_name=profile_name,
|
||||
text=generation.text,
|
||||
language=generation.language,
|
||||
audio_path=audio_path,
|
||||
duration=generation.duration,
|
||||
seed=generation.seed,
|
||||
instruct=generation.instruct,
|
||||
generation_created_at=generation.created_at,
|
||||
versions=versions,
|
||||
active_version_id=active_version_id,
|
||||
)
|
||||
|
||||
|
||||
async def create_story(
|
||||
data: StoryCreate,
|
||||
db: Session,
|
||||
) -> StoryResponse:
|
||||
"""
|
||||
Create a new story.
|
||||
|
||||
Args:
|
||||
data: Story creation data
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Created story
|
||||
"""
|
||||
db_story = DBStory(
|
||||
id=str(uuid.uuid4()),
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(db_story)
|
||||
db.commit()
|
||||
db.refresh(db_story)
|
||||
|
||||
item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == db_story.id).scalar()
|
||||
|
||||
response = StoryResponse.model_validate(db_story)
|
||||
response.item_count = item_count
|
||||
return response
|
||||
|
||||
|
||||
async def list_stories(
|
||||
db: Session,
|
||||
) -> List[StoryResponse]:
|
||||
"""
|
||||
List all stories.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of stories with item counts
|
||||
"""
|
||||
stories = db.query(DBStory).order_by(DBStory.updated_at.desc()).all()
|
||||
|
||||
result = []
|
||||
for story in stories:
|
||||
item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == story.id).scalar()
|
||||
|
||||
response = StoryResponse.model_validate(story)
|
||||
response.item_count = item_count
|
||||
result.append(response)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def get_story(
|
||||
story_id: str,
|
||||
db: Session,
|
||||
) -> Optional[StoryDetailResponse]:
|
||||
"""
|
||||
Get a story with all its items.
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Story with items or None if not found
|
||||
"""
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if not story:
|
||||
return None
|
||||
|
||||
items = (
|
||||
db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
|
||||
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
|
||||
.join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
|
||||
.filter(DBStoryItem.story_id == story_id)
|
||||
.order_by(DBStoryItem.start_time_ms)
|
||||
.all()
|
||||
)
|
||||
|
||||
item_details = []
|
||||
for item, generation, profile_name in items:
|
||||
item_details.append(_build_item_detail(item, generation, profile_name, db))
|
||||
|
||||
response = StoryDetailResponse.model_validate(story)
|
||||
response.items = item_details
|
||||
return response
|
||||
|
||||
|
||||
async def update_story(
|
||||
story_id: str,
|
||||
data: StoryCreate,
|
||||
db: Session,
|
||||
) -> Optional[StoryResponse]:
|
||||
"""
|
||||
Update a story.
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
data: Update data
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Updated story or None if not found
|
||||
"""
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if not story:
|
||||
return None
|
||||
|
||||
story.name = data.name
|
||||
story.description = data.description
|
||||
story.updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(story)
|
||||
|
||||
item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == story.id).scalar()
|
||||
|
||||
response = StoryResponse.model_validate(story)
|
||||
response.item_count = item_count
|
||||
return response
|
||||
|
||||
|
||||
async def delete_story(
|
||||
story_id: str,
|
||||
db: Session,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a story and all its items.
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if not story:
|
||||
return False
|
||||
|
||||
# Delete all items
|
||||
db.query(DBStoryItem).filter_by(story_id=story_id).delete()
|
||||
|
||||
# Delete story
|
||||
db.delete(story)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def add_item_to_story(
|
||||
story_id: str,
|
||||
data: StoryItemCreate,
|
||||
db: Session,
|
||||
) -> Optional[StoryItemDetail]:
|
||||
"""
|
||||
Add a generation to a story.
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
data: Item creation data
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Created item detail or None if story/generation not found
|
||||
"""
|
||||
# Verify story exists
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if not story:
|
||||
return None
|
||||
|
||||
# Verify generation exists
|
||||
generation = db.query(DBGeneration).filter_by(id=data.generation_id).first()
|
||||
if not generation:
|
||||
return None
|
||||
|
||||
# Check if generation is already in story
|
||||
existing = db.query(DBStoryItem).filter_by(story_id=story_id, generation_id=data.generation_id).first()
|
||||
if existing:
|
||||
# Return existing item
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
|
||||
return _build_item_detail(existing, generation, profile.name if profile else "Unknown", db)
|
||||
|
||||
# Get track from data or default to 0
|
||||
track = data.track if data.track is not None else 0
|
||||
|
||||
# Calculate start_time_ms if not provided
|
||||
if data.start_time_ms is not None:
|
||||
start_time_ms = data.start_time_ms
|
||||
else:
|
||||
existing_items = (
|
||||
db.query(DBStoryItem, DBGeneration)
|
||||
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
|
||||
.filter(
|
||||
DBStoryItem.story_id == story_id,
|
||||
DBStoryItem.track == track,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not existing_items:
|
||||
start_time_ms = 0
|
||||
else:
|
||||
max_end_time_ms = 0
|
||||
for item, gen in existing_items:
|
||||
item_end_ms = item.start_time_ms + int(gen.duration * 1000)
|
||||
max_end_time_ms = max(max_end_time_ms, item_end_ms)
|
||||
|
||||
# Add 200ms gap after the last item
|
||||
start_time_ms = max_end_time_ms + 200
|
||||
|
||||
# Create item
|
||||
item = DBStoryItem(
|
||||
id=str(uuid.uuid4()),
|
||||
story_id=story_id,
|
||||
generation_id=data.generation_id,
|
||||
start_time_ms=start_time_ms,
|
||||
track=track,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(item)
|
||||
|
||||
# Update story updated_at
|
||||
story.updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
|
||||
# Get profile name
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
|
||||
|
||||
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
|
||||
|
||||
|
||||
async def move_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: StoryItemMove,
|
||||
db: Session,
|
||||
) -> Optional[StoryItemDetail]:
|
||||
"""
|
||||
Move a story item (update position and/or track).
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
item_id: Story item ID
|
||||
data: New position and track data
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Updated item detail or None if not found
|
||||
"""
|
||||
# Get the item
|
||||
item = (
|
||||
db.query(DBStoryItem)
|
||||
.filter_by(
|
||||
id=item_id,
|
||||
story_id=story_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not item:
|
||||
return None
|
||||
|
||||
# Get the generation
|
||||
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
|
||||
if not generation:
|
||||
return None
|
||||
|
||||
# Update position and track
|
||||
item.start_time_ms = data.start_time_ms
|
||||
item.track = data.track
|
||||
|
||||
# Update story updated_at
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if story:
|
||||
story.updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
|
||||
# Get profile name
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
|
||||
|
||||
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
|
||||
|
||||
|
||||
async def remove_item_from_story(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
db: Session,
|
||||
) -> bool:
|
||||
"""
|
||||
Remove a story item from a story.
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
item_id: Story item ID to remove
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
True if removed, False if not found
|
||||
"""
|
||||
item = (
|
||||
db.query(DBStoryItem)
|
||||
.filter_by(
|
||||
id=item_id,
|
||||
story_id=story_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not item:
|
||||
return False
|
||||
|
||||
# Delete item
|
||||
db.delete(item)
|
||||
|
||||
# Update story updated_at
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if story:
|
||||
story.updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def trim_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: StoryItemTrim,
|
||||
db: Session,
|
||||
) -> Optional[StoryItemDetail]:
|
||||
"""
|
||||
Trim a story item (update trim_start_ms and trim_end_ms).
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
item_id: Story item ID
|
||||
data: Trim data (trim_start_ms, trim_end_ms)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Updated item detail or None if not found
|
||||
"""
|
||||
# Get the item
|
||||
item = (
|
||||
db.query(DBStoryItem)
|
||||
.filter_by(
|
||||
id=item_id,
|
||||
story_id=story_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not item:
|
||||
return None
|
||||
|
||||
# Get the generation
|
||||
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
|
||||
if not generation:
|
||||
return None
|
||||
|
||||
# Validate trim values don't exceed duration
|
||||
max_duration_ms = int(generation.duration * 1000)
|
||||
if data.trim_start_ms + data.trim_end_ms >= max_duration_ms:
|
||||
return None # Invalid trim - would result in zero or negative duration
|
||||
|
||||
# Update trim values
|
||||
item.trim_start_ms = data.trim_start_ms
|
||||
item.trim_end_ms = data.trim_end_ms
|
||||
|
||||
# Update story updated_at
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if story:
|
||||
story.updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
|
||||
# Get profile name
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
|
||||
|
||||
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
|
||||
|
||||
|
||||
async def split_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: StoryItemSplit,
|
||||
db: Session,
|
||||
) -> Optional[List[StoryItemDetail]]:
|
||||
"""
|
||||
Split a story item at a given time, creating two clips.
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
item_id: Story item ID to split
|
||||
data: Split data (split_time_ms - time within clip to split at)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of two updated item details (original and new) or None if not found/invalid
|
||||
"""
|
||||
# Get the item with a row lock to prevent concurrent splits on the
|
||||
# same clip (e.g. from rapid double-clicks racing each other).
|
||||
item = (
|
||||
db.query(DBStoryItem)
|
||||
.filter_by(
|
||||
id=item_id,
|
||||
story_id=story_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not item:
|
||||
return None
|
||||
|
||||
# Get the generation
|
||||
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
|
||||
if not generation:
|
||||
return None
|
||||
|
||||
# Calculate effective duration and validate split point
|
||||
current_trim_start = getattr(item, "trim_start_ms", 0)
|
||||
current_trim_end = getattr(item, "trim_end_ms", 0)
|
||||
original_duration_ms = int(generation.duration * 1000)
|
||||
effective_duration_ms = original_duration_ms - current_trim_start - current_trim_end
|
||||
|
||||
# Validate split_time_ms is within the effective duration
|
||||
if data.split_time_ms <= 0 or data.split_time_ms >= effective_duration_ms:
|
||||
return None # Invalid split point
|
||||
|
||||
# Calculate the absolute time in the original audio where we're splitting
|
||||
absolute_split_ms = current_trim_start + data.split_time_ms
|
||||
|
||||
# Update original clip: trim from the end
|
||||
item.trim_end_ms = original_duration_ms - absolute_split_ms
|
||||
|
||||
# Create new clip: starts after the split, trimmed from the start
|
||||
new_item = DBStoryItem(
|
||||
id=str(uuid.uuid4()),
|
||||
story_id=story_id,
|
||||
generation_id=item.generation_id, # Same generation, different trim
|
||||
version_id=getattr(item, "version_id", None), # Preserve pinned version
|
||||
start_time_ms=item.start_time_ms + data.split_time_ms,
|
||||
track=item.track,
|
||||
trim_start_ms=absolute_split_ms,
|
||||
trim_end_ms=current_trim_end,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(new_item)
|
||||
|
||||
# Update story updated_at
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if story:
|
||||
story.updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
db.refresh(new_item)
|
||||
|
||||
# Get profile name
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
|
||||
profile_name = profile.name if profile else "Unknown"
|
||||
|
||||
return [
|
||||
_build_item_detail(item, generation, profile_name, db),
|
||||
_build_item_detail(new_item, generation, profile_name, db),
|
||||
]
|
||||
|
||||
|
||||
async def duplicate_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
db: Session,
|
||||
) -> Optional[StoryItemDetail]:
|
||||
"""
|
||||
Duplicate a story item, creating a copy with all properties.
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
item_id: Story item ID to duplicate
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
New item detail or None if not found
|
||||
"""
|
||||
# Get the original item
|
||||
original_item = (
|
||||
db.query(DBStoryItem)
|
||||
.filter_by(
|
||||
id=item_id,
|
||||
story_id=story_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not original_item:
|
||||
return None
|
||||
|
||||
# Get the generation
|
||||
generation = db.query(DBGeneration).filter_by(id=original_item.generation_id).first()
|
||||
if not generation:
|
||||
return None
|
||||
|
||||
# Calculate effective duration
|
||||
current_trim_start = getattr(original_item, "trim_start_ms", 0)
|
||||
current_trim_end = getattr(original_item, "trim_end_ms", 0)
|
||||
original_duration_ms = int(generation.duration * 1000)
|
||||
effective_duration_ms = original_duration_ms - current_trim_start - current_trim_end
|
||||
|
||||
# Create duplicate item - place it right after the original
|
||||
new_item = DBStoryItem(
|
||||
id=str(uuid.uuid4()),
|
||||
story_id=story_id,
|
||||
generation_id=original_item.generation_id, # Same generation as original
|
||||
version_id=getattr(original_item, "version_id", None), # Preserve pinned version
|
||||
start_time_ms=original_item.start_time_ms + effective_duration_ms + 200, # 200ms gap
|
||||
track=original_item.track,
|
||||
trim_start_ms=current_trim_start,
|
||||
trim_end_ms=current_trim_end,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(new_item)
|
||||
|
||||
# Update story updated_at
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if story:
|
||||
story.updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(new_item)
|
||||
|
||||
# Get profile name
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
|
||||
|
||||
return _build_item_detail(new_item, generation, profile.name if profile else "Unknown", db)
|
||||
|
||||
|
||||
async def update_story_item_times(
|
||||
story_id: str,
|
||||
data: StoryItemBatchUpdate,
|
||||
db: Session,
|
||||
) -> bool:
|
||||
"""
|
||||
Update story item timecodes.
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
data: Batch update data with timecodes
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
True if updated, False if story not found or invalid
|
||||
"""
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if not story:
|
||||
return False
|
||||
|
||||
# Get all items for this story
|
||||
items = db.query(DBStoryItem).filter_by(story_id=story_id).all()
|
||||
item_map = {item.generation_id: item for item in items}
|
||||
|
||||
# Verify all generation IDs belong to this story and update timecodes
|
||||
for update in data.updates:
|
||||
if update.generation_id not in item_map:
|
||||
return False
|
||||
item_map[update.generation_id].start_time_ms = update.start_time_ms
|
||||
|
||||
# Update story updated_at
|
||||
story.updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def reorder_story_items(
|
||||
story_id: str,
|
||||
generation_ids: List[str],
|
||||
db: Session,
|
||||
gap_ms: int = 200,
|
||||
) -> Optional[List[StoryItemDetail]]:
|
||||
"""
|
||||
Reorder story items and recalculate timecodes.
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
generation_ids: List of generation IDs in the desired order
|
||||
db: Database session
|
||||
gap_ms: Gap in milliseconds between items (default 200ms)
|
||||
|
||||
Returns:
|
||||
Updated list of story items with new timecodes, or None if invalid
|
||||
"""
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if not story:
|
||||
return None
|
||||
|
||||
# Get all items for this story with their generation data
|
||||
items_with_gen = (
|
||||
db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
|
||||
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
|
||||
.join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
|
||||
.filter(DBStoryItem.story_id == story_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create maps for quick lookup
|
||||
item_map = {item.generation_id: (item, gen, profile_name) for item, gen, profile_name in items_with_gen}
|
||||
|
||||
# Verify all generation IDs belong to this story
|
||||
if set(generation_ids) != set(item_map.keys()):
|
||||
return None
|
||||
|
||||
# Recalculate timecodes based on new order
|
||||
current_time_ms = 0
|
||||
updated_items = []
|
||||
|
||||
for gen_id in generation_ids:
|
||||
item, generation, profile_name = item_map[gen_id]
|
||||
|
||||
# Update the item's start time
|
||||
item.start_time_ms = current_time_ms
|
||||
|
||||
# Calculate the duration in ms
|
||||
duration_ms = int(generation.duration * 1000)
|
||||
|
||||
# Move to next position (current end + gap)
|
||||
current_time_ms += duration_ms + gap_ms
|
||||
|
||||
# Build the response item
|
||||
updated_items.append(_build_item_detail(item, generation, profile_name, db))
|
||||
|
||||
# Update story updated_at
|
||||
story.updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
return updated_items
|
||||
|
||||
|
||||
async def set_story_item_version(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: StoryItemVersionUpdate,
|
||||
db: Session,
|
||||
) -> Optional[StoryItemDetail]:
|
||||
"""
|
||||
Pin a story item to a specific generation version.
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
item_id: Story item ID
|
||||
data: Version update data (version_id or null for default)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Updated item detail or None if not found
|
||||
"""
|
||||
item = (
|
||||
db.query(DBStoryItem)
|
||||
.filter_by(
|
||||
id=item_id,
|
||||
story_id=story_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not item:
|
||||
return None
|
||||
|
||||
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
|
||||
if not generation:
|
||||
return None
|
||||
|
||||
# Validate version_id belongs to this generation if provided
|
||||
if data.version_id:
|
||||
from ..database import GenerationVersion as DBGenerationVersion
|
||||
|
||||
version = (
|
||||
db.query(DBGenerationVersion)
|
||||
.filter_by(
|
||||
id=data.version_id,
|
||||
generation_id=item.generation_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not version:
|
||||
return None
|
||||
|
||||
item.version_id = data.version_id
|
||||
|
||||
# Update story updated_at
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if story:
|
||||
story.updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
|
||||
|
||||
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
|
||||
|
||||
|
||||
async def export_story_audio(
|
||||
story_id: str,
|
||||
db: Session,
|
||||
) -> Optional[bytes]:
|
||||
"""
|
||||
Export story as single mixed audio file with timecode-based mixing.
|
||||
|
||||
Args:
|
||||
story_id: Story ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Audio file bytes or None if story not found
|
||||
"""
|
||||
story = db.query(DBStory).filter_by(id=story_id).first()
|
||||
if not story:
|
||||
return None
|
||||
|
||||
# Get all items ordered by start_time_ms
|
||||
items = (
|
||||
db.query(DBStoryItem, DBGeneration)
|
||||
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
|
||||
.filter(DBStoryItem.story_id == story_id)
|
||||
.order_by(DBStoryItem.start_time_ms)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not items:
|
||||
return None
|
||||
|
||||
# Load all audio files and calculate total duration
|
||||
audio_data = []
|
||||
sample_rate = 24000 # Default sample rate
|
||||
|
||||
for item, generation in items:
|
||||
# Resolve audio path: use pinned version if set, otherwise generation default
|
||||
resolved_audio_path = generation.audio_path
|
||||
if getattr(item, "version_id", None):
|
||||
from ..database import GenerationVersion as DBGenerationVersion
|
||||
|
||||
version = db.query(DBGenerationVersion).filter_by(id=item.version_id).first()
|
||||
if version:
|
||||
resolved_audio_path = version.audio_path
|
||||
|
||||
audio_path = config.resolve_storage_path(resolved_audio_path)
|
||||
if audio_path is None or not audio_path.exists():
|
||||
continue
|
||||
|
||||
try:
|
||||
audio, sr = load_audio(str(audio_path), sample_rate=sample_rate)
|
||||
sample_rate = sr # Use actual sample rate from first file
|
||||
|
||||
# Get trim values
|
||||
trim_start_ms = getattr(item, "trim_start_ms", 0)
|
||||
trim_end_ms = getattr(item, "trim_end_ms", 0)
|
||||
|
||||
# Calculate effective duration
|
||||
original_duration_ms = int(generation.duration * 1000)
|
||||
effective_duration_ms = original_duration_ms - trim_start_ms - trim_end_ms
|
||||
|
||||
# Slice audio based on trim values
|
||||
trim_start_sample = int((trim_start_ms / 1000.0) * sample_rate)
|
||||
trim_end_sample = int((trim_end_ms / 1000.0) * sample_rate)
|
||||
|
||||
# Extract the trimmed portion
|
||||
if trim_end_ms > 0:
|
||||
trimmed_audio = (
|
||||
audio[trim_start_sample:-trim_end_sample] if trim_end_sample > 0 else audio[trim_start_sample:]
|
||||
)
|
||||
else:
|
||||
trimmed_audio = audio[trim_start_sample:]
|
||||
|
||||
# Store audio with its timecode info
|
||||
start_time_ms = item.start_time_ms
|
||||
|
||||
audio_data.append(
|
||||
{
|
||||
"audio": trimmed_audio,
|
||||
"start_time_ms": start_time_ms,
|
||||
"duration_ms": effective_duration_ms,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Skip files that can't be loaded
|
||||
continue
|
||||
|
||||
if not audio_data:
|
||||
return None
|
||||
|
||||
# Calculate total duration: max(start_time_ms + duration_ms)
|
||||
max_end_time_ms = max((data["start_time_ms"] + data["duration_ms"] for data in audio_data), default=0)
|
||||
|
||||
# Convert to samples
|
||||
total_samples = int((max_end_time_ms / 1000.0) * sample_rate)
|
||||
|
||||
# Create output buffer initialized to zeros
|
||||
final_audio = np.zeros(total_samples, dtype=np.float32)
|
||||
|
||||
# Mix each audio segment at its timecode position
|
||||
for data in audio_data:
|
||||
audio = data["audio"]
|
||||
start_time_ms = data["start_time_ms"]
|
||||
|
||||
# Calculate start sample index
|
||||
start_sample = int((start_time_ms / 1000.0) * sample_rate)
|
||||
|
||||
# Ensure we don't exceed buffer bounds
|
||||
audio_length = len(audio)
|
||||
end_sample = min(start_sample + audio_length, total_samples)
|
||||
|
||||
if start_sample < total_samples:
|
||||
# Trim audio if it extends beyond buffer
|
||||
audio_to_mix = audio[: end_sample - start_sample]
|
||||
|
||||
# Mix: add audio to existing buffer (overlapping audio will sum)
|
||||
# Normalize to prevent clipping (simple approach: divide by max)
|
||||
final_audio[start_sample:end_sample] += audio_to_mix
|
||||
|
||||
# Normalize to prevent clipping
|
||||
max_val = np.abs(final_audio).max()
|
||||
if max_val > 1.0:
|
||||
final_audio = final_audio / max_val
|
||||
|
||||
# Save to temporary file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
save_audio(final_audio, tmp_path, sample_rate)
|
||||
|
||||
# Read file bytes
|
||||
with open(tmp_path, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
|
||||
return audio_bytes
|
||||
finally:
|
||||
# Clean up temp file
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
Reference in New Issue
Block a user