369 lines
10 KiB
Python
369 lines
10 KiB
Python
"""
|
|
Generation history management module.
|
|
"""
|
|
|
|
from typing import List, Optional, Tuple
|
|
from datetime import datetime
|
|
import uuid
|
|
import shutil
|
|
from pathlib import Path
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import or_
|
|
|
|
from ..models import GenerationRequest, GenerationResponse, HistoryQuery, HistoryResponse, HistoryListResponse, GenerationVersionResponse, EffectConfig
|
|
from ..database import Generation as DBGeneration, GenerationVersion as DBGenerationVersion, VoiceProfile as DBVoiceProfile
|
|
from .. import config
|
|
|
|
|
|
def _get_versions_for_generation(
    generation_id: str,
    db: Session,
) -> Tuple[Optional[List[GenerationVersionResponse]], Optional[str]]:
    """
    Get versions list and active version ID for a generation.

    Versions are returned oldest-first (ordered by ``created_at``). The active
    version is the one flagged ``is_default``; if several rows are flagged,
    the last one in that order wins.

    Args:
        generation_id: ID of the generation whose versions should be loaded.
        db: Database session.

    Returns:
        ``(versions, active_version_id)``. Both are None when the generation
        has no version rows; ``active_version_id`` is also None when no
        version is flagged as default.
    """
    import json
    import logging

    versions_rows = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=generation_id)
        .order_by(DBGenerationVersion.created_at)
        .all()
    )
    if not versions_rows:
        return None, None

    versions = []
    active_version_id = None
    for v in versions_rows:
        # The effects chain is stored as a JSON string of EffectConfig dicts.
        effects_chain = None
        if v.effects_chain:
            try:
                raw = json.loads(v.effects_chain)
                effects_chain = [EffectConfig(**e) for e in raw]
            except Exception:
                # Best-effort: a malformed stored chain must not break history
                # listing, but leave a trace instead of dropping it silently.
                logging.getLogger(__name__).warning(
                    "Ignoring unparseable effects_chain on generation version %s",
                    v.id,
                    exc_info=True,
                )

        versions.append(GenerationVersionResponse(
            id=v.id,
            generation_id=v.generation_id,
            label=v.label,
            audio_path=v.audio_path,
            effects_chain=effects_chain,
            is_default=v.is_default,
            created_at=v.created_at,
        ))
        if v.is_default:
            active_version_id = v.id

    return versions, active_version_id
|
|
|
|
|
|
async def create_generation(
    profile_id: str,
    text: str,
    language: str,
    audio_path: str,
    duration: float,
    seed: Optional[int],
    db: Session,
    instruct: Optional[str] = None,
    generation_id: Optional[str] = None,
    status: str = "completed",
    engine: Optional[str] = "qwen",
    model_size: Optional[str] = None,
) -> GenerationResponse:
    """
    Create a new generation history entry.

    The row is added, committed and refreshed before being converted into a
    response model.

    Args:
        profile_id: Profile ID used for generation
        text: Generated text
        language: Language code
        audio_path: Path where audio was saved
        duration: Audio duration in seconds
        seed: Random seed used (if any)
        db: Database session
        instruct: Natural language instruction used (if any)
        generation_id: Pre-assigned ID (for async generation flow); when
            omitted (or empty) a random UUID4 string is used instead
        status: Generation status (generating, completed, failed)
        engine: TTS engine used (qwen, luxtts, chatterbox, chatterbox_turbo)
        model_size: Model size variant (1.7B, 0.6B) — only relevant for qwen

    Returns:
        Created generation entry
    """
    from datetime import timezone

    db_generation = DBGeneration(
        id=generation_id or str(uuid.uuid4()),
        profile_id=profile_id,
        text=text,
        language=language,
        audio_path=audio_path,
        duration=duration,
        seed=seed,
        instruct=instruct,
        engine=engine,
        model_size=model_size,
        status=status,
        # Naive UTC timestamp, identical in value to datetime.utcnow(), which
        # is deprecated since Python 3.12.
        created_at=datetime.now(timezone.utc).replace(tzinfo=None),
    )

    db.add(db_generation)
    db.commit()
    db.refresh(db_generation)

    return GenerationResponse.model_validate(db_generation)
|
|
|
|
|
|
async def update_generation_status(
    generation_id: str,
    status: str,
    db: Session,
    audio_path: Optional[str] = None,
    duration: Optional[float] = None,
    error: Optional[str] = None,
) -> Optional[GenerationResponse]:
    """
    Update the status of a generation (used by async generation flow).

    The status is always written. ``audio_path``, ``duration`` and ``error``
    are only written when they are not None, so callers can pass just the
    fields that changed.

    Args:
        generation_id: ID of the generation to update.
        status: New status value.
        db: Database session.
        audio_path: New audio path, if any.
        duration: New duration in seconds, if any.
        error: Error message to record, if any.

    Returns:
        The refreshed generation, or None when no row has that ID.
    """
    row = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not row:
        return None

    row.status = status
    optional_fields = (
        ("audio_path", audio_path),
        ("duration", duration),
        ("error", error),
    )
    for attr_name, new_value in optional_fields:
        if new_value is not None:
            setattr(row, attr_name, new_value)

    db.commit()
    db.refresh(row)
    return GenerationResponse.model_validate(row)
|
|
|
|
|
|
async def get_generation(
    generation_id: str,
    db: Session,
) -> Optional[GenerationResponse]:
    """
    Fetch a single generation by its ID.

    Args:
        generation_id: ID of the generation to look up.
        db: Database session.

    Returns:
        The generation as a ``GenerationResponse``, or None when no row
        has that ID.
    """
    found = (
        db.query(DBGeneration)
        .filter_by(id=generation_id)
        .first()
    )
    return GenerationResponse.model_validate(found) if found else None
|
|
|
|
|
|
async def list_generations(
    query: HistoryQuery,
    db: Session,
) -> HistoryListResponse:
    """
    List generations with optional filters.

    Generations are inner-joined with their voice profile to pick up the
    profile name, so rows without a matching profile are not listed. The
    results are filtered by ``query.profile_id`` and/or ``query.search``,
    which is a literal substring match on the generation text. They are
    ordered newest first and paginated with ``query.offset`` /
    ``query.limit``.

    Args:
        query: Query parameters (filters, pagination)
        db: Database session

    Returns:
        HistoryListResponse with the requested page of items and the total
        number of matching generations (counted before pagination).
    """
    # Build base query with join to get profile name
    q = db.query(
        DBGeneration,
        DBVoiceProfile.name.label('profile_name'),
    ).join(
        DBVoiceProfile,
        DBGeneration.profile_id == DBVoiceProfile.id,
    )

    # Apply profile filter
    if query.profile_id:
        q = q.filter(DBGeneration.profile_id == query.profile_id)

    # Apply search filter on the text content. autoescape=True escapes the
    # LIKE wildcards ('%', '_') and the escape char itself in the user's
    # search term, so e.g. "100%" matches literally instead of as a pattern.
    if query.search:
        q = q.filter(DBGeneration.text.contains(query.search, autoescape=True))

    # Get total count before pagination
    total_count = q.count()

    # Newest first, then paginate
    q = q.order_by(DBGeneration.created_at.desc())
    q = q.offset(query.offset).limit(query.limit)

    results = q.all()

    items = []
    for generation, profile_name in results:
        # NOTE(review): one versions query per row (N+1). This is bounded by
        # the page size but could be batched if pages grow large.
        versions, active_version_id = _get_versions_for_generation(generation.id, db)
        items.append(HistoryResponse(
            id=generation.id,
            profile_id=generation.profile_id,
            profile_name=profile_name,
            text=generation.text,
            language=generation.language,
            audio_path=generation.audio_path,
            duration=generation.duration,
            seed=generation.seed,
            instruct=generation.instruct,
            # Fallbacks for rows where these columns are empty.
            engine=generation.engine or "qwen",
            model_size=generation.model_size,
            status=generation.status or "completed",
            error=generation.error,
            is_favorited=bool(generation.is_favorited),
            created_at=generation.created_at,
            versions=versions,
            active_version_id=active_version_id,
        ))

    return HistoryListResponse(
        items=items,
        total=total_count,
    )
|
|
|
|
|
|
async def delete_generation(
    generation_id: str,
    db: Session,
) -> bool:
    """
    Delete a generation, its versions and its main audio file.

    Removing the audio file is best-effort. A filesystem error does not
    prevent the database row from being deleted.

    Args:
        generation_id: Generation ID
        db: Database session

    Returns:
        True if deleted, False if not found
    """
    generation = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not generation:
        return False

    # Delete all version files and records
    from . import versions as versions_mod
    versions_mod.delete_versions_for_generation(generation_id, db)

    # Delete main audio file (if not already removed by version cleanup).
    # Version cleanup has already run at this point, so an OSError here (e.g.
    # the file vanished after exists(), or a permission error) must not abort
    # the delete and leave a generation row whose versions are gone.
    if generation.audio_path:
        audio_path = config.resolve_storage_path(generation.audio_path)
        if audio_path is not None and audio_path.exists():
            try:
                audio_path.unlink()
            except OSError:
                pass

    # Delete from database
    db.delete(generation)
    db.commit()

    return True
|
|
|
|
|
|
async def delete_failed_generations(db: Session) -> int:
    """
    Delete every generation whose status is 'failed'.

    This backs the "Clear failed" action in the UI. It lets users tidy up
    history after the model wasn't loaded, the app was closed mid-run, or a
    generation otherwise errored out (see issue #410).

    For each failed generation, its versions are cleaned up first. Its main
    audio file is then removed on a best-effort basis, and finally the row
    itself is deleted. All deletions are committed once at the end.

    Args:
        db: Database session.

    Returns:
        Number of generations deleted.
    """
    from . import versions as versions_mod

    failed_rows = (
        db.query(DBGeneration)
        .filter(DBGeneration.status == "failed")
        .all()
    )

    for row in failed_rows:
        versions_mod.delete_versions_for_generation(row.id, db)

        stored_path = row.audio_path
        resolved = config.resolve_storage_path(stored_path) if stored_path else None
        if resolved is not None and resolved.exists():
            try:
                resolved.unlink()
            except OSError:
                # One undeletable file shouldn't abort the whole sweep.
                pass

        db.delete(row)

    db.commit()
    return len(failed_rows)
|
|
|
|
|
|
async def delete_generations_by_profile(
    profile_id: str,
    db: Session,
) -> int:
    """
    Delete all generations for a profile, including versions and audio files.

    Audio-file removal is best-effort, so one undeletable file does not abort
    the sweep. All row deletions are committed once at the end.

    Args:
        profile_id: Profile ID
        db: Database session

    Returns:
        Number of generations deleted
    """
    # Imported once, not on every loop iteration.
    from . import versions as versions_mod

    generations = db.query(DBGeneration).filter_by(profile_id=profile_id).all()

    count = 0
    for generation in generations:
        # Delete associated version files and rows first
        versions_mod.delete_versions_for_generation(generation.id, db)

        # Delete the main audio file. Guard on audio_path like the other
        # delete helpers: failed/in-progress rows may have no stored path,
        # and it must not be resolved and unlinked.
        if generation.audio_path:
            audio_path = config.resolve_storage_path(generation.audio_path)
            if audio_path is not None and audio_path.exists():
                try:
                    audio_path.unlink()
                except OSError:
                    # Best-effort cleanup; keep deleting the remaining rows.
                    pass

        # Delete from database
        db.delete(generation)
        count += 1

    db.commit()

    return count
|
|
|
|
|
|
async def get_generation_stats(db: Session) -> dict:
    """
    Compute aggregate statistics over all generations.

    Args:
        db: Database session.

    Returns:
        A dict with:
          - "total_generations": number of generation rows,
          - "total_duration_seconds": sum of ``duration`` (0 when the sum is
            NULL/empty),
          - "generations_by_profile": mapping of profile_id -> row count.
    """
    from sqlalchemy import func

    generation_count = db.query(func.count(DBGeneration.id)).scalar()
    summed_duration = db.query(func.sum(DBGeneration.duration)).scalar()

    per_profile_rows = (
        db.query(
            DBGeneration.profile_id,
            func.count(DBGeneration.id).label('count'),
        )
        .group_by(DBGeneration.profile_id)
        .all()
    )

    counts_by_profile = {}
    for pid, n in per_profile_rows:
        counts_by_profile[pid] = n

    return {
        "total_generations": generation_count,
        "total_duration_seconds": summed_duration or 0,
        "generations_by_profile": counts_by_profile,
    }
|