Initial commit
This commit is contained in:
383
backend/utils/hf_progress.py
Normal file
383
backend/utils/hf_progress.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""
|
||||
HuggingFace Hub download progress tracking.
|
||||
"""
|
||||
|
||||
from typing import Optional, Callable
|
||||
from contextlib import contextmanager
|
||||
import logging
|
||||
import threading
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HFProgressTracker:
|
||||
"""Tracks HuggingFace Hub download progress by intercepting tqdm."""
|
||||
|
||||
def __init__(self, progress_callback: Optional[Callable] = None, filter_non_downloads: bool = False):
|
||||
self.progress_callback = progress_callback
|
||||
self.filter_non_downloads = filter_non_downloads # Only filter if True
|
||||
self._original_tqdm_class = None
|
||||
self._lock = threading.Lock()
|
||||
self._total_downloaded = 0
|
||||
self._total_size = 0
|
||||
self._file_sizes = {} # Track sizes of individual files
|
||||
self._file_downloaded = {} # Track downloaded bytes per file
|
||||
self._current_filename = ""
|
||||
self._active_tqdms = {} # Track active tqdm instances
|
||||
self._hf_tqdm_original_update = None # For monkey-patching hf's tqdm
|
||||
|
||||
def _create_tracked_tqdm_class(self):
|
||||
"""Create a tqdm subclass that tracks progress."""
|
||||
tracker = self
|
||||
original_tqdm = self._original_tqdm_class
|
||||
|
||||
class TrackedTqdm(original_tqdm):
|
||||
"""A tqdm subclass that reports progress to our tracker."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Extract filename from desc before passing to parent
|
||||
desc = kwargs.get("desc", "")
|
||||
if not desc and args:
|
||||
first_arg = args[0]
|
||||
if isinstance(first_arg, str):
|
||||
desc = first_arg
|
||||
|
||||
filename = ""
|
||||
if desc:
|
||||
# Try to extract filename from description
|
||||
# HuggingFace Hub uses format like "model.safetensors: 0%|..."
|
||||
if ":" in desc:
|
||||
filename = desc.split(":")[0].strip()
|
||||
else:
|
||||
filename = desc.strip()
|
||||
|
||||
# Filter out non-standard kwargs that huggingface_hub might pass
|
||||
# These are custom kwargs that tqdm doesn't understand
|
||||
filtered_kwargs = {}
|
||||
# Known tqdm kwargs - pass these through
|
||||
tqdm_kwargs = {
|
||||
"iterable",
|
||||
"desc",
|
||||
"total",
|
||||
"leave",
|
||||
"file",
|
||||
"ncols",
|
||||
"mininterval",
|
||||
"maxinterval",
|
||||
"miniters",
|
||||
"ascii",
|
||||
"disable",
|
||||
"unit",
|
||||
"unit_scale",
|
||||
"dynamic_ncols",
|
||||
"smoothing",
|
||||
"bar_format",
|
||||
"initial",
|
||||
"position",
|
||||
"postfix",
|
||||
"unit_divisor",
|
||||
"write_bytes",
|
||||
"lock_args",
|
||||
"nrows",
|
||||
"colour",
|
||||
"color",
|
||||
"delay",
|
||||
"gui",
|
||||
"disable_default",
|
||||
"pos",
|
||||
}
|
||||
for key, value in kwargs.items():
|
||||
if key in tqdm_kwargs:
|
||||
filtered_kwargs[key] = value
|
||||
|
||||
# Force-enable the progress bar — we're tracking progress ourselves,
|
||||
# we don't need tqdm to render to a terminal, but we DO need
|
||||
# self.n to be updated when update() is called.
|
||||
filtered_kwargs["disable"] = False
|
||||
|
||||
# Try to initialize with filtered kwargs, fall back to all kwargs if that fails
|
||||
try:
|
||||
super().__init__(*args, **filtered_kwargs)
|
||||
except TypeError:
|
||||
# If filtering failed, try with all kwargs (maybe tqdm version accepts them)
|
||||
kwargs["disable"] = False
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._tracker_filename = filename or "unknown"
|
||||
|
||||
with tracker._lock:
|
||||
if filename:
|
||||
tracker._current_filename = filename
|
||||
tracker._active_tqdms[id(self)] = {
|
||||
"filename": self._tracker_filename,
|
||||
}
|
||||
|
||||
def update(self, n=1):
|
||||
result = super().update(n)
|
||||
|
||||
# Report progress
|
||||
with tracker._lock:
|
||||
if id(self) in tracker._active_tqdms:
|
||||
filename = tracker._active_tqdms[id(self)]["filename"]
|
||||
current = getattr(self, "n", 0)
|
||||
total = getattr(self, "total", 0)
|
||||
|
||||
if total and total > 0:
|
||||
# Always filter out non-byte progress bars (e.g., "Fetching 12 files")
|
||||
# These cause crazy percentages because they're counting files, not bytes
|
||||
if self._is_non_byte_progress(filename):
|
||||
return result
|
||||
|
||||
# When model is cached, also filter out generation-related progress
|
||||
if tracker.filter_non_downloads:
|
||||
if not self._is_download_progress(filename):
|
||||
return result
|
||||
|
||||
# Update per-file tracking
|
||||
tracker._file_sizes[filename] = total
|
||||
tracker._file_downloaded[filename] = current
|
||||
|
||||
# Calculate totals across all files
|
||||
tracker._total_size = sum(tracker._file_sizes.values())
|
||||
tracker._total_downloaded = sum(tracker._file_downloaded.values())
|
||||
|
||||
# Only report progress once we have a meaningful total (at least 1MB)
|
||||
# This avoids the "100% at 0MB" issue when small config
|
||||
# files are counted before the real model files
|
||||
MIN_TOTAL_BYTES = 1_000_000 # 1MB
|
||||
if tracker._total_size < MIN_TOTAL_BYTES:
|
||||
return result
|
||||
|
||||
# Call progress callback
|
||||
if tracker.progress_callback:
|
||||
tracker.progress_callback(tracker._total_downloaded, tracker._total_size, filename)
|
||||
|
||||
return result
|
||||
|
||||
def _is_non_byte_progress(self, filename: str) -> bool:
|
||||
"""Check if this progress bar should be SKIPPED (returns True to skip).
|
||||
|
||||
We want to track byte-based progress bars. This method identifies
|
||||
progress bars that count files/items instead of bytes, which would
|
||||
cause crazy percentages if mixed with our byte counting.
|
||||
|
||||
Returns:
|
||||
True = SKIP this bar (it's not byte-based)
|
||||
False = TRACK this bar (it counts bytes)
|
||||
"""
|
||||
if not filename:
|
||||
return False
|
||||
|
||||
filename_lower = filename.lower()
|
||||
|
||||
# Skip "Fetching X files" - it counts files (total=12), not bytes
|
||||
# Don't skip "Downloading (incomplete total...)" - that IS byte-based
|
||||
skip_patterns = [
|
||||
"fetching", # "Fetching 12 files" has total=12 files, not bytes
|
||||
]
|
||||
return any(pattern in filename_lower for pattern in skip_patterns)
|
||||
|
||||
def _is_download_progress(self, filename: str) -> bool:
|
||||
"""Check if this is a real file download progress bar vs internal processing."""
|
||||
if not filename or filename == "unknown":
|
||||
return False
|
||||
|
||||
# Real downloads have file extensions
|
||||
download_extensions = [
|
||||
".safetensors",
|
||||
".bin",
|
||||
".pt",
|
||||
".pth", # Model weights
|
||||
".json",
|
||||
".txt",
|
||||
".py", # Config files
|
||||
".msgpack",
|
||||
".h5", # Other formats
|
||||
]
|
||||
|
||||
filename_lower = filename.lower()
|
||||
has_extension = any(filename_lower.endswith(ext) for ext in download_extensions)
|
||||
|
||||
# Skip generation-related progress indicators
|
||||
skip_patterns = ["segment", "processing", "generating", "loading"]
|
||||
has_skip_pattern = any(pattern in filename_lower for pattern in skip_patterns)
|
||||
|
||||
return has_extension and not has_skip_pattern
|
||||
|
||||
def close(self):
|
||||
with tracker._lock:
|
||||
if id(self) in tracker._active_tqdms:
|
||||
del tracker._active_tqdms[id(self)]
|
||||
return super().close()
|
||||
|
||||
return TrackedTqdm
|
||||
|
||||
@contextmanager
|
||||
def patch_download(self):
|
||||
"""Context manager to patch tqdm for progress tracking."""
|
||||
try:
|
||||
import tqdm as tqdm_module
|
||||
|
||||
# Store original tqdm class
|
||||
self._original_tqdm_class = tqdm_module.tqdm
|
||||
|
||||
# Reset totals
|
||||
with self._lock:
|
||||
self._total_downloaded = 0
|
||||
self._total_size = 0
|
||||
self._file_sizes = {}
|
||||
self._file_downloaded = {}
|
||||
self._current_filename = ""
|
||||
self._active_tqdms = {}
|
||||
|
||||
# Create our tracked tqdm class
|
||||
tracked_tqdm = self._create_tracked_tqdm_class()
|
||||
|
||||
# Patch tqdm.tqdm
|
||||
tqdm_module.tqdm = tracked_tqdm
|
||||
|
||||
# Also patch tqdm.auto.tqdm if it exists (used by huggingface_hub)
|
||||
self._original_tqdm_auto = None
|
||||
if hasattr(tqdm_module, "auto") and hasattr(tqdm_module.auto, "tqdm"):
|
||||
self._original_tqdm_auto = tqdm_module.auto.tqdm
|
||||
tqdm_module.auto.tqdm = tracked_tqdm
|
||||
|
||||
# Patch in sys.modules to catch already-imported references
|
||||
# huggingface_hub uses: from tqdm.auto import tqdm as base_tqdm
|
||||
# So we need to patch both 'tqdm' and 'base_tqdm' attributes
|
||||
self._patched_modules = {}
|
||||
tqdm_attr_names = ["tqdm", "base_tqdm", "old_tqdm"] # Various names used
|
||||
|
||||
patched_count = 0
|
||||
for module_name in list(sys.modules.keys()):
|
||||
if "huggingface" in module_name or module_name.startswith("tqdm"):
|
||||
try:
|
||||
module = sys.modules[module_name]
|
||||
for attr_name in tqdm_attr_names:
|
||||
if hasattr(module, attr_name):
|
||||
attr = getattr(module, attr_name)
|
||||
# Only patch if it's a tqdm class (not already patched)
|
||||
is_tqdm_class = (
|
||||
attr is self._original_tqdm_class
|
||||
or (self._original_tqdm_auto and attr is self._original_tqdm_auto)
|
||||
or (
|
||||
hasattr(attr, "__name__")
|
||||
and attr.__name__ == "tqdm"
|
||||
and hasattr(attr, "update")
|
||||
) # tqdm classes have update method
|
||||
)
|
||||
if is_tqdm_class:
|
||||
key = f"{module_name}.{attr_name}"
|
||||
self._patched_modules[key] = (module, attr_name, attr)
|
||||
setattr(module, attr_name, tracked_tqdm)
|
||||
patched_count += 1
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
# ALSO monkey-patch the update method on huggingface_hub's tqdm class
|
||||
# This is needed because the class was already defined at import time
|
||||
self._hf_tqdm_original_update = None
|
||||
try:
|
||||
from huggingface_hub.utils import tqdm as hf_tqdm_module
|
||||
|
||||
if hasattr(hf_tqdm_module, "tqdm"):
|
||||
hf_tqdm_class = hf_tqdm_module.tqdm
|
||||
self._hf_tqdm_original_update = hf_tqdm_class.update
|
||||
|
||||
# Create a wrapper that calls our tracking
|
||||
tracker = self # Reference to HFProgressTracker instance
|
||||
|
||||
def patched_update(tqdm_self, n=1):
|
||||
result = tracker._hf_tqdm_original_update(tqdm_self, n)
|
||||
|
||||
# Track this progress
|
||||
with tracker._lock:
|
||||
desc = getattr(tqdm_self, "desc", "") or ""
|
||||
current = getattr(tqdm_self, "n", 0)
|
||||
total = getattr(tqdm_self, "total", 0) or 0
|
||||
|
||||
# Skip non-byte progress bars
|
||||
if "fetching" in desc.lower():
|
||||
return result
|
||||
|
||||
# Skip until we have a meaningful total (at least 1MB)
|
||||
# This avoids the "100% at 0MB" issue when small config
|
||||
# files are counted before the real model files
|
||||
MIN_TOTAL_BYTES = 1_000_000 # 1MB
|
||||
if total >= MIN_TOTAL_BYTES:
|
||||
tracker._total_downloaded = current
|
||||
tracker._total_size = total
|
||||
|
||||
if tracker.progress_callback:
|
||||
tracker.progress_callback(current, total, desc)
|
||||
|
||||
return result
|
||||
|
||||
hf_tqdm_class.update = patched_update
|
||||
patched_count += 1
|
||||
logger.debug("Monkey-patched huggingface_hub.utils.tqdm.tqdm.update")
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.warning("Could not monkey-patch hf_tqdm: %s", e)
|
||||
|
||||
logger.debug("Patched %d tqdm references", patched_count)
|
||||
|
||||
yield
|
||||
|
||||
except ImportError:
|
||||
# If tqdm not available, just yield without patching
|
||||
yield
|
||||
finally:
|
||||
# Restore original tqdm
|
||||
if self._original_tqdm_class:
|
||||
try:
|
||||
import tqdm as tqdm_module
|
||||
|
||||
tqdm_module.tqdm = self._original_tqdm_class
|
||||
|
||||
if self._original_tqdm_auto:
|
||||
tqdm_module.auto.tqdm = self._original_tqdm_auto
|
||||
|
||||
# Restore patched modules
|
||||
for key, (module, attr_name, original) in self._patched_modules.items():
|
||||
try:
|
||||
if module and original:
|
||||
setattr(module, attr_name, original)
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
self._patched_modules = {}
|
||||
|
||||
# Restore hf_tqdm's original update method
|
||||
if self._hf_tqdm_original_update:
|
||||
try:
|
||||
from huggingface_hub.utils import tqdm as hf_tqdm_module
|
||||
|
||||
if hasattr(hf_tqdm_module, "tqdm"):
|
||||
hf_tqdm_module.tqdm.update = self._hf_tqdm_original_update
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
self._hf_tqdm_original_update = None
|
||||
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
def create_hf_progress_callback(model_name: str, progress_manager):
|
||||
"""Create a progress callback for HuggingFace downloads."""
|
||||
|
||||
def callback(downloaded: int, total: int, filename: str = ""):
|
||||
"""Progress callback.
|
||||
|
||||
Note: We send updates even when total=0 (unknown) to provide feedback
|
||||
during the "incomplete total" phase of huggingface_hub downloads.
|
||||
The frontend handles total=0 gracefully.
|
||||
"""
|
||||
progress_manager.update_progress(
|
||||
model_name=model_name,
|
||||
current=downloaded,
|
||||
total=total,
|
||||
filename=filename or "",
|
||||
status="downloading",
|
||||
)
|
||||
|
||||
return callback
|
||||
Reference in New Issue
Block a user