Initial commit
This commit is contained in:
368
backend/services/history.py
Normal file
368
backend/services/history.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Generation history management module.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_
|
||||
|
||||
from ..models import GenerationRequest, GenerationResponse, HistoryQuery, HistoryResponse, HistoryListResponse, GenerationVersionResponse, EffectConfig
|
||||
from ..database import Generation as DBGeneration, GenerationVersion as DBGenerationVersion, VoiceProfile as DBVoiceProfile
|
||||
from .. import config
|
||||
|
||||
|
||||
def _get_versions_for_generation(generation_id: str, db: Session) -> tuple:
|
||||
"""Get versions list and active version ID for a generation."""
|
||||
import json
|
||||
versions_rows = (
|
||||
db.query(DBGenerationVersion)
|
||||
.filter_by(generation_id=generation_id)
|
||||
.order_by(DBGenerationVersion.created_at)
|
||||
.all()
|
||||
)
|
||||
if not versions_rows:
|
||||
return None, None
|
||||
|
||||
versions = []
|
||||
active_version_id = None
|
||||
for v in versions_rows:
|
||||
effects_chain = None
|
||||
if v.effects_chain:
|
||||
try:
|
||||
raw = json.loads(v.effects_chain)
|
||||
effects_chain = [EffectConfig(**e) for e in raw]
|
||||
except Exception:
|
||||
pass
|
||||
versions.append(GenerationVersionResponse(
|
||||
id=v.id,
|
||||
generation_id=v.generation_id,
|
||||
label=v.label,
|
||||
audio_path=v.audio_path,
|
||||
effects_chain=effects_chain,
|
||||
is_default=v.is_default,
|
||||
created_at=v.created_at,
|
||||
))
|
||||
if v.is_default:
|
||||
active_version_id = v.id
|
||||
|
||||
return versions, active_version_id
|
||||
|
||||
|
||||
async def create_generation(
|
||||
profile_id: str,
|
||||
text: str,
|
||||
language: str,
|
||||
audio_path: str,
|
||||
duration: float,
|
||||
seed: Optional[int],
|
||||
db: Session,
|
||||
instruct: Optional[str] = None,
|
||||
generation_id: Optional[str] = None,
|
||||
status: str = "completed",
|
||||
engine: Optional[str] = "qwen",
|
||||
model_size: Optional[str] = None,
|
||||
) -> GenerationResponse:
|
||||
"""
|
||||
Create a new generation history entry.
|
||||
|
||||
Args:
|
||||
profile_id: Profile ID used for generation
|
||||
text: Generated text
|
||||
language: Language code
|
||||
audio_path: Path where audio was saved
|
||||
duration: Audio duration in seconds
|
||||
seed: Random seed used (if any)
|
||||
db: Database session
|
||||
instruct: Natural language instruction used (if any)
|
||||
generation_id: Pre-assigned ID (for async generation flow)
|
||||
status: Generation status (generating, completed, failed)
|
||||
engine: TTS engine used (qwen, luxtts, chatterbox, chatterbox_turbo)
|
||||
model_size: Model size variant (1.7B, 0.6B) — only relevant for qwen
|
||||
|
||||
Returns:
|
||||
Created generation entry
|
||||
"""
|
||||
db_generation = DBGeneration(
|
||||
id=generation_id or str(uuid.uuid4()),
|
||||
profile_id=profile_id,
|
||||
text=text,
|
||||
language=language,
|
||||
audio_path=audio_path,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
instruct=instruct,
|
||||
engine=engine,
|
||||
model_size=model_size,
|
||||
status=status,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(db_generation)
|
||||
db.commit()
|
||||
db.refresh(db_generation)
|
||||
|
||||
return GenerationResponse.model_validate(db_generation)
|
||||
|
||||
|
||||
async def update_generation_status(
|
||||
generation_id: str,
|
||||
status: str,
|
||||
db: Session,
|
||||
audio_path: Optional[str] = None,
|
||||
duration: Optional[float] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> Optional[GenerationResponse]:
|
||||
"""Update the status of a generation (used by async generation flow)."""
|
||||
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not generation:
|
||||
return None
|
||||
|
||||
generation.status = status
|
||||
if audio_path is not None:
|
||||
generation.audio_path = audio_path
|
||||
if duration is not None:
|
||||
generation.duration = duration
|
||||
if error is not None:
|
||||
generation.error = error
|
||||
|
||||
db.commit()
|
||||
db.refresh(generation)
|
||||
return GenerationResponse.model_validate(generation)
|
||||
|
||||
|
||||
async def get_generation(
|
||||
generation_id: str,
|
||||
db: Session,
|
||||
) -> Optional[GenerationResponse]:
|
||||
"""
|
||||
Get a generation by ID.
|
||||
|
||||
Args:
|
||||
generation_id: Generation ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Generation or None if not found
|
||||
"""
|
||||
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not generation:
|
||||
return None
|
||||
|
||||
return GenerationResponse.model_validate(generation)
|
||||
|
||||
|
||||
async def list_generations(
|
||||
query: HistoryQuery,
|
||||
db: Session,
|
||||
) -> HistoryListResponse:
|
||||
"""
|
||||
List generations with optional filters.
|
||||
|
||||
Args:
|
||||
query: Query parameters (filters, pagination)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
HistoryListResponse with items and total count
|
||||
"""
|
||||
# Build base query with join to get profile name
|
||||
q = db.query(
|
||||
DBGeneration,
|
||||
DBVoiceProfile.name.label('profile_name')
|
||||
).join(
|
||||
DBVoiceProfile,
|
||||
DBGeneration.profile_id == DBVoiceProfile.id
|
||||
)
|
||||
|
||||
# Apply profile filter
|
||||
if query.profile_id:
|
||||
q = q.filter(DBGeneration.profile_id == query.profile_id)
|
||||
|
||||
# Apply search filter (searches in text content)
|
||||
if query.search:
|
||||
search_pattern = f"%{query.search}%"
|
||||
q = q.filter(DBGeneration.text.like(search_pattern))
|
||||
|
||||
# Get total count before pagination
|
||||
total_count = q.count()
|
||||
|
||||
# Apply ordering (newest first)
|
||||
q = q.order_by(DBGeneration.created_at.desc())
|
||||
|
||||
# Apply pagination
|
||||
q = q.offset(query.offset).limit(query.limit)
|
||||
|
||||
# Execute query
|
||||
results = q.all()
|
||||
|
||||
# Convert to HistoryResponse with profile_name
|
||||
items = []
|
||||
for generation, profile_name in results:
|
||||
versions, active_version_id = _get_versions_for_generation(generation.id, db)
|
||||
items.append(HistoryResponse(
|
||||
id=generation.id,
|
||||
profile_id=generation.profile_id,
|
||||
profile_name=profile_name,
|
||||
text=generation.text,
|
||||
language=generation.language,
|
||||
audio_path=generation.audio_path,
|
||||
duration=generation.duration,
|
||||
seed=generation.seed,
|
||||
instruct=generation.instruct,
|
||||
engine=generation.engine or "qwen",
|
||||
model_size=generation.model_size,
|
||||
status=generation.status or "completed",
|
||||
error=generation.error,
|
||||
is_favorited=bool(generation.is_favorited),
|
||||
created_at=generation.created_at,
|
||||
versions=versions,
|
||||
active_version_id=active_version_id,
|
||||
))
|
||||
|
||||
return HistoryListResponse(
|
||||
items=items,
|
||||
total=total_count,
|
||||
)
|
||||
|
||||
|
||||
async def delete_generation(
|
||||
generation_id: str,
|
||||
db: Session,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a generation.
|
||||
|
||||
Args:
|
||||
generation_id: Generation ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not generation:
|
||||
return False
|
||||
|
||||
# Delete all version files and records
|
||||
from . import versions as versions_mod
|
||||
versions_mod.delete_versions_for_generation(generation_id, db)
|
||||
|
||||
# Delete main audio file (if not already removed by version cleanup)
|
||||
if generation.audio_path:
|
||||
audio_path = config.resolve_storage_path(generation.audio_path)
|
||||
if audio_path is not None and audio_path.exists():
|
||||
audio_path.unlink()
|
||||
|
||||
# Delete from database
|
||||
db.delete(generation)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def delete_failed_generations(db: Session) -> int:
|
||||
"""
|
||||
Delete every generation whose status is 'failed'.
|
||||
|
||||
Used by the "Clear failed" action in the UI so users can tidy up
|
||||
history after the model wasn't loaded, the app was closed mid-run,
|
||||
or a generation otherwise errored out (see issue #410).
|
||||
|
||||
Returns:
|
||||
Number of generations deleted.
|
||||
"""
|
||||
from . import versions as versions_mod
|
||||
|
||||
failed = db.query(DBGeneration).filter(DBGeneration.status == "failed").all()
|
||||
count = 0
|
||||
for generation in failed:
|
||||
# Clean up version files/rows first.
|
||||
versions_mod.delete_versions_for_generation(generation.id, db)
|
||||
|
||||
# Remove the main audio file if it somehow made it to disk.
|
||||
if generation.audio_path:
|
||||
audio_path = config.resolve_storage_path(generation.audio_path)
|
||||
if audio_path is not None and audio_path.exists():
|
||||
try:
|
||||
audio_path.unlink()
|
||||
except OSError:
|
||||
# Best-effort cleanup — don't abort the whole sweep
|
||||
# if a single file can't be removed.
|
||||
pass
|
||||
|
||||
db.delete(generation)
|
||||
count += 1
|
||||
|
||||
db.commit()
|
||||
return count
|
||||
|
||||
|
||||
async def delete_generations_by_profile(
|
||||
profile_id: str,
|
||||
db: Session,
|
||||
) -> int:
|
||||
"""
|
||||
Delete all generations for a profile.
|
||||
|
||||
Args:
|
||||
profile_id: Profile ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Number of generations deleted
|
||||
"""
|
||||
generations = db.query(DBGeneration).filter_by(profile_id=profile_id).all()
|
||||
|
||||
count = 0
|
||||
for generation in generations:
|
||||
# Delete associated version files and rows first
|
||||
from . import versions as versions_mod
|
||||
versions_mod.delete_versions_for_generation(generation.id, db)
|
||||
|
||||
# Delete audio file
|
||||
audio_path = config.resolve_storage_path(generation.audio_path)
|
||||
if audio_path is not None and audio_path.exists():
|
||||
audio_path.unlink()
|
||||
|
||||
# Delete from database
|
||||
db.delete(generation)
|
||||
count += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
return count
|
||||
|
||||
|
||||
async def get_generation_stats(db: Session) -> dict:
|
||||
"""
|
||||
Get generation statistics.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Statistics dictionary
|
||||
"""
|
||||
from sqlalchemy import func
|
||||
|
||||
total = db.query(func.count(DBGeneration.id)).scalar()
|
||||
|
||||
total_duration = db.query(func.sum(DBGeneration.duration)).scalar() or 0
|
||||
|
||||
# Get generations by profile
|
||||
by_profile = db.query(
|
||||
DBGeneration.profile_id,
|
||||
func.count(DBGeneration.id).label('count')
|
||||
).group_by(DBGeneration.profile_id).all()
|
||||
|
||||
return {
|
||||
"total_generations": total,
|
||||
"total_duration_seconds": total_duration,
|
||||
"generations_by_profile": {
|
||||
profile_id: count for profile_id, count in by_profile
|
||||
},
|
||||
}
|
||||
Reference in New Issue
Block a user