Initial commit
This commit is contained in:
461
backend/services/export_import.py
Normal file
461
backend/services/export_import.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
Voice profile export/import module.
|
||||
|
||||
Handles exporting profiles to ZIP archives and importing them back.
|
||||
Also handles exporting individual generations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import zipfile
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import VoiceProfileResponse
|
||||
from ..database import VoiceProfile as DBVoiceProfile, ProfileSample as DBProfileSample, Generation as DBGeneration, GenerationVersion as DBGenerationVersion
|
||||
from .profiles import create_profile, add_profile_sample
|
||||
from ..models import VoiceProfileCreate
|
||||
from .. import config
|
||||
|
||||
|
||||
def _get_unique_profile_name(name: str, db: Session) -> str:
|
||||
"""
|
||||
Get a unique profile name by appending a number if needed.
|
||||
|
||||
Args:
|
||||
name: Original profile name
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Unique profile name
|
||||
"""
|
||||
base_name = name
|
||||
counter = 1
|
||||
|
||||
while True:
|
||||
existing = db.query(DBVoiceProfile).filter_by(name=name).first()
|
||||
if not existing:
|
||||
return name
|
||||
|
||||
name = f"{base_name} ({counter})"
|
||||
counter += 1
|
||||
|
||||
|
||||
def export_profile_to_zip(profile_id: str, db: Session) -> bytes:
|
||||
"""
|
||||
Export a voice profile to a ZIP archive.
|
||||
|
||||
Args:
|
||||
profile_id: Profile ID to export
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ZIP file contents as bytes
|
||||
|
||||
Raises:
|
||||
ValueError: If profile not found or has no samples
|
||||
"""
|
||||
# Get profile
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
|
||||
if not profile:
|
||||
raise ValueError(f"Profile {profile_id} not found")
|
||||
|
||||
# Get all samples
|
||||
samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()
|
||||
if not samples:
|
||||
raise ValueError(f"Profile {profile_id} has no samples")
|
||||
|
||||
# Create ZIP in memory
|
||||
zip_buffer = io.BytesIO()
|
||||
|
||||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
||||
# Check if profile has avatar
|
||||
has_avatar = False
|
||||
if profile.avatar_path:
|
||||
avatar_path = config.resolve_storage_path(profile.avatar_path)
|
||||
if avatar_path is not None and avatar_path.exists():
|
||||
has_avatar = True
|
||||
# Add avatar to ZIP root with original extension
|
||||
avatar_ext = avatar_path.suffix
|
||||
zip_file.write(avatar_path, f"avatar{avatar_ext}")
|
||||
|
||||
# Create manifest.json
|
||||
manifest = {
|
||||
"version": "1.0",
|
||||
"profile": {
|
||||
"name": profile.name,
|
||||
"description": profile.description,
|
||||
"language": profile.language,
|
||||
},
|
||||
"has_avatar": has_avatar,
|
||||
}
|
||||
zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))
|
||||
|
||||
# Create samples.json mapping
|
||||
samples_data = {}
|
||||
profile_dir = config.get_profiles_dir() / profile_id
|
||||
|
||||
for sample in samples:
|
||||
# Get filename from audio_path (should be {sample_id}.wav)
|
||||
audio_path = config.resolve_storage_path(sample.audio_path)
|
||||
if audio_path is None:
|
||||
raise ValueError(f"Audio file not found: {sample.audio_path}")
|
||||
filename = audio_path.name
|
||||
|
||||
# Read audio file
|
||||
if not audio_path.exists():
|
||||
raise ValueError(f"Audio file not found: {audio_path}")
|
||||
|
||||
# Add to samples directory in ZIP
|
||||
zip_path = f"samples/{filename}"
|
||||
zip_file.write(audio_path, zip_path)
|
||||
|
||||
# Map filename to reference text
|
||||
samples_data[filename] = sample.reference_text
|
||||
|
||||
zip_file.writestr("samples.json", json.dumps(samples_data, indent=2))
|
||||
|
||||
zip_buffer.seek(0)
|
||||
return zip_buffer.read()
|
||||
|
||||
|
||||
async def import_profile_from_zip(file_bytes: bytes, db: Session) -> VoiceProfileResponse:
|
||||
"""
|
||||
Import a voice profile from a ZIP archive.
|
||||
|
||||
Args:
|
||||
file_bytes: ZIP file contents
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Created profile
|
||||
|
||||
Raises:
|
||||
ValueError: If ZIP is invalid or missing required files
|
||||
"""
|
||||
zip_buffer = io.BytesIO(file_bytes)
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_buffer, 'r') as zip_file:
|
||||
# Validate ZIP structure
|
||||
namelist = zip_file.namelist()
|
||||
|
||||
if "manifest.json" not in namelist:
|
||||
raise ValueError("ZIP archive missing manifest.json")
|
||||
|
||||
if "samples.json" not in namelist:
|
||||
raise ValueError("ZIP archive missing samples.json")
|
||||
|
||||
# Read manifest
|
||||
manifest_data = json.loads(zip_file.read("manifest.json"))
|
||||
|
||||
if "version" not in manifest_data:
|
||||
raise ValueError("Invalid manifest.json: missing version")
|
||||
|
||||
if "profile" not in manifest_data:
|
||||
raise ValueError("Invalid manifest.json: missing profile")
|
||||
|
||||
profile_data = manifest_data["profile"]
|
||||
|
||||
# Read samples mapping
|
||||
samples_data = json.loads(zip_file.read("samples.json"))
|
||||
|
||||
if not isinstance(samples_data, dict):
|
||||
raise ValueError("Invalid samples.json: must be a dictionary")
|
||||
|
||||
# Get unique profile name
|
||||
original_name = profile_data.get("name", "Imported Profile")
|
||||
unique_name = _get_unique_profile_name(original_name, db)
|
||||
|
||||
# Create profile
|
||||
profile_create = VoiceProfileCreate(
|
||||
name=unique_name,
|
||||
description=profile_data.get("description"),
|
||||
language=profile_data.get("language", "en"),
|
||||
)
|
||||
|
||||
profile = await create_profile(profile_create, db)
|
||||
|
||||
# Extract and add samples
|
||||
profile_dir = config.get_profiles_dir() / profile.id
|
||||
profile_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Handle avatar if present
|
||||
avatar_files = [f for f in namelist if f.startswith("avatar.")]
|
||||
if avatar_files:
|
||||
try:
|
||||
avatar_file = avatar_files[0]
|
||||
# Extract to temporary file
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(suffix=Path(avatar_file).suffix, delete=False) as tmp:
|
||||
tmp.write(zip_file.read(avatar_file))
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
from .profiles import upload_avatar
|
||||
await upload_avatar(profile.id, tmp_path, db)
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
# Avatar import is optional - continue even if it fails
|
||||
pass
|
||||
|
||||
for filename, reference_text in samples_data.items():
|
||||
# Validate filename
|
||||
if not filename.endswith('.wav'):
|
||||
raise ValueError(f"Invalid sample filename: {filename} (must be .wav)")
|
||||
|
||||
# Extract audio file to temp location
|
||||
zip_path = f"samples/{filename}"
|
||||
|
||||
if zip_path not in namelist:
|
||||
raise ValueError(f"Sample file not found in ZIP: {zip_path}")
|
||||
|
||||
# Extract to temporary file
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||
tmp.write(zip_file.read(zip_path))
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
# Add sample to profile
|
||||
await add_profile_sample(
|
||||
profile.id,
|
||||
tmp_path,
|
||||
reference_text,
|
||||
db,
|
||||
)
|
||||
finally:
|
||||
# Clean up temp file
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
return profile
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
raise ValueError("Invalid ZIP file")
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in archive: {e}")
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f"Error importing profile: {str(e)}")
|
||||
|
||||
|
||||
def export_generation_to_zip(generation_id: str, db: Session) -> bytes:
|
||||
"""
|
||||
Export a generation to a ZIP archive.
|
||||
|
||||
Args:
|
||||
generation_id: Generation ID to export
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ZIP file contents as bytes
|
||||
|
||||
Raises:
|
||||
ValueError: If generation not found
|
||||
"""
|
||||
# Get generation
|
||||
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not generation:
|
||||
raise ValueError(f"Generation {generation_id} not found")
|
||||
|
||||
# Get profile info
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
|
||||
if not profile:
|
||||
raise ValueError(f"Profile {generation.profile_id} not found")
|
||||
|
||||
# Get all versions for this generation
|
||||
versions = (
|
||||
db.query(DBGenerationVersion)
|
||||
.filter_by(generation_id=generation_id)
|
||||
.order_by(DBGenerationVersion.created_at)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create ZIP in memory
|
||||
zip_buffer = io.BytesIO()
|
||||
|
||||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
||||
# Build version manifest entries
|
||||
version_entries = []
|
||||
for v in versions:
|
||||
v_path = config.resolve_storage_path(v.audio_path)
|
||||
effects_chain = None
|
||||
if v.effects_chain:
|
||||
effects_chain = json.loads(v.effects_chain)
|
||||
version_entries.append({
|
||||
"id": v.id,
|
||||
"label": v.label,
|
||||
"is_default": v.is_default,
|
||||
"effects_chain": effects_chain,
|
||||
"filename": v_path.name,
|
||||
})
|
||||
|
||||
manifest = {
|
||||
"version": "1.0",
|
||||
"generation": {
|
||||
"id": generation.id,
|
||||
"text": generation.text,
|
||||
"language": generation.language,
|
||||
"duration": generation.duration,
|
||||
"seed": generation.seed,
|
||||
"instruct": generation.instruct,
|
||||
"created_at": generation.created_at.isoformat(),
|
||||
},
|
||||
"profile": {
|
||||
"id": profile.id,
|
||||
"name": profile.name,
|
||||
"description": profile.description,
|
||||
"language": profile.language,
|
||||
},
|
||||
"versions": version_entries,
|
||||
}
|
||||
zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))
|
||||
|
||||
# Add all version audio files
|
||||
for v in versions:
|
||||
v_path = config.resolve_storage_path(v.audio_path)
|
||||
if v_path is not None and v_path.exists():
|
||||
zip_file.write(v_path, f"audio/{v_path.name}")
|
||||
|
||||
# Fallback: if no versions exist, include the generation's main audio
|
||||
if not versions:
|
||||
audio_path = config.resolve_storage_path(generation.audio_path)
|
||||
if audio_path is not None and audio_path.exists():
|
||||
zip_file.write(audio_path, f"audio/{audio_path.name}")
|
||||
|
||||
zip_buffer.seek(0)
|
||||
return zip_buffer.read()
|
||||
|
||||
|
||||
async def import_generation_from_zip(file_bytes: bytes, db: Session) -> dict:
|
||||
"""
|
||||
Import a generation from a ZIP archive.
|
||||
|
||||
Args:
|
||||
file_bytes: ZIP file contents
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary with generation ID and profile info
|
||||
|
||||
Raises:
|
||||
ValueError: If ZIP is invalid or missing required files
|
||||
"""
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from .. import config
|
||||
|
||||
zip_buffer = io.BytesIO(file_bytes)
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_buffer, 'r') as zip_file:
|
||||
# Validate ZIP structure
|
||||
namelist = zip_file.namelist()
|
||||
|
||||
if "manifest.json" not in namelist:
|
||||
raise ValueError("ZIP archive missing manifest.json")
|
||||
|
||||
# Read manifest
|
||||
manifest_data = json.loads(zip_file.read("manifest.json"))
|
||||
|
||||
if "version" not in manifest_data:
|
||||
raise ValueError("Invalid manifest.json: missing version")
|
||||
|
||||
if "generation" not in manifest_data:
|
||||
raise ValueError("Invalid manifest.json: missing generation data")
|
||||
|
||||
generation_data = manifest_data["generation"]
|
||||
profile_data = manifest_data.get("profile", {})
|
||||
|
||||
# Validate required fields
|
||||
required_fields = ["text", "language", "duration"]
|
||||
for field in required_fields:
|
||||
if field not in generation_data:
|
||||
raise ValueError(f"Invalid manifest.json: missing generation.{field}")
|
||||
|
||||
# Find audio file in archive
|
||||
audio_files = [f for f in namelist if f.startswith("audio/") and f.endswith(".wav")]
|
||||
if not audio_files:
|
||||
raise ValueError("No audio file found in ZIP archive")
|
||||
|
||||
audio_file_path = audio_files[0]
|
||||
|
||||
# Check if we should match an existing profile or create metadata
|
||||
profile_id = None
|
||||
profile_name = profile_data.get("name", "Unknown Profile")
|
||||
|
||||
# Try to find matching profile by name
|
||||
if profile_name and profile_name != "Unknown Profile":
|
||||
existing_profile = db.query(DBVoiceProfile).filter_by(name=profile_name).first()
|
||||
if existing_profile:
|
||||
profile_id = existing_profile.id
|
||||
|
||||
# If no matching profile, use a placeholder or the first available profile
|
||||
if not profile_id:
|
||||
# Get any profile, or None if no profiles exist
|
||||
any_profile = db.query(DBVoiceProfile).first()
|
||||
if any_profile:
|
||||
profile_id = any_profile.id
|
||||
profile_name = any_profile.name
|
||||
else:
|
||||
raise ValueError("No voice profiles found. Please create a profile before importing generations.")
|
||||
|
||||
# Extract audio file to temporary location
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||
tmp.write(zip_file.read(audio_file_path))
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
# Create generations directory
|
||||
generations_dir = config.get_generations_dir()
|
||||
generations_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate new ID for this generation
|
||||
new_generation_id = str(__import__('uuid').uuid4())
|
||||
|
||||
# Copy audio to generations directory
|
||||
audio_dest = generations_dir / f"{new_generation_id}.wav"
|
||||
shutil.copy(tmp_path, audio_dest)
|
||||
|
||||
# Create generation record
|
||||
db_generation = DBGeneration(
|
||||
id=new_generation_id,
|
||||
profile_id=profile_id,
|
||||
text=generation_data["text"],
|
||||
language=generation_data["language"],
|
||||
audio_path=config.to_storage_path(audio_dest),
|
||||
duration=generation_data["duration"],
|
||||
seed=generation_data.get("seed"),
|
||||
instruct=generation_data.get("instruct"),
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(db_generation)
|
||||
db.commit()
|
||||
db.refresh(db_generation)
|
||||
|
||||
return {
|
||||
"id": db_generation.id,
|
||||
"profile_id": profile_id,
|
||||
"profile_name": profile_name,
|
||||
"text": db_generation.text,
|
||||
"message": f"Generation imported successfully (assigned to profile: {profile_name})"
|
||||
}
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
raise ValueError("Invalid ZIP file")
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON in archive: {e}")
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f"Error importing generation: {str(e)}")
|
||||
Reference in New Issue
Block a user