384 lines
16 KiB
Python
384 lines
16 KiB
Python
"""
|
|
HuggingFace Hub download progress tracking.
|
|
"""
|
|
|
|
from typing import Optional, Callable
|
|
from contextlib import contextmanager
|
|
import logging
|
|
import threading
|
|
import sys
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class HFProgressTracker:
|
|
"""Tracks HuggingFace Hub download progress by intercepting tqdm."""
|
|
|
|
def __init__(self, progress_callback: Optional[Callable] = None, filter_non_downloads: bool = False):
|
|
self.progress_callback = progress_callback
|
|
self.filter_non_downloads = filter_non_downloads # Only filter if True
|
|
self._original_tqdm_class = None
|
|
self._lock = threading.Lock()
|
|
self._total_downloaded = 0
|
|
self._total_size = 0
|
|
self._file_sizes = {} # Track sizes of individual files
|
|
self._file_downloaded = {} # Track downloaded bytes per file
|
|
self._current_filename = ""
|
|
self._active_tqdms = {} # Track active tqdm instances
|
|
self._hf_tqdm_original_update = None # For monkey-patching hf's tqdm
|
|
|
|
def _create_tracked_tqdm_class(self):
|
|
"""Create a tqdm subclass that tracks progress."""
|
|
tracker = self
|
|
original_tqdm = self._original_tqdm_class
|
|
|
|
class TrackedTqdm(original_tqdm):
|
|
"""A tqdm subclass that reports progress to our tracker."""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
# Extract filename from desc before passing to parent
|
|
desc = kwargs.get("desc", "")
|
|
if not desc and args:
|
|
first_arg = args[0]
|
|
if isinstance(first_arg, str):
|
|
desc = first_arg
|
|
|
|
filename = ""
|
|
if desc:
|
|
# Try to extract filename from description
|
|
# HuggingFace Hub uses format like "model.safetensors: 0%|..."
|
|
if ":" in desc:
|
|
filename = desc.split(":")[0].strip()
|
|
else:
|
|
filename = desc.strip()
|
|
|
|
# Filter out non-standard kwargs that huggingface_hub might pass
|
|
# These are custom kwargs that tqdm doesn't understand
|
|
filtered_kwargs = {}
|
|
# Known tqdm kwargs - pass these through
|
|
tqdm_kwargs = {
|
|
"iterable",
|
|
"desc",
|
|
"total",
|
|
"leave",
|
|
"file",
|
|
"ncols",
|
|
"mininterval",
|
|
"maxinterval",
|
|
"miniters",
|
|
"ascii",
|
|
"disable",
|
|
"unit",
|
|
"unit_scale",
|
|
"dynamic_ncols",
|
|
"smoothing",
|
|
"bar_format",
|
|
"initial",
|
|
"position",
|
|
"postfix",
|
|
"unit_divisor",
|
|
"write_bytes",
|
|
"lock_args",
|
|
"nrows",
|
|
"colour",
|
|
"color",
|
|
"delay",
|
|
"gui",
|
|
"disable_default",
|
|
"pos",
|
|
}
|
|
for key, value in kwargs.items():
|
|
if key in tqdm_kwargs:
|
|
filtered_kwargs[key] = value
|
|
|
|
# Force-enable the progress bar — we're tracking progress ourselves,
|
|
# we don't need tqdm to render to a terminal, but we DO need
|
|
# self.n to be updated when update() is called.
|
|
filtered_kwargs["disable"] = False
|
|
|
|
# Try to initialize with filtered kwargs, fall back to all kwargs if that fails
|
|
try:
|
|
super().__init__(*args, **filtered_kwargs)
|
|
except TypeError:
|
|
# If filtering failed, try with all kwargs (maybe tqdm version accepts them)
|
|
kwargs["disable"] = False
|
|
super().__init__(*args, **kwargs)
|
|
|
|
self._tracker_filename = filename or "unknown"
|
|
|
|
with tracker._lock:
|
|
if filename:
|
|
tracker._current_filename = filename
|
|
tracker._active_tqdms[id(self)] = {
|
|
"filename": self._tracker_filename,
|
|
}
|
|
|
|
def update(self, n=1):
|
|
result = super().update(n)
|
|
|
|
# Report progress
|
|
with tracker._lock:
|
|
if id(self) in tracker._active_tqdms:
|
|
filename = tracker._active_tqdms[id(self)]["filename"]
|
|
current = getattr(self, "n", 0)
|
|
total = getattr(self, "total", 0)
|
|
|
|
if total and total > 0:
|
|
# Always filter out non-byte progress bars (e.g., "Fetching 12 files")
|
|
# These cause crazy percentages because they're counting files, not bytes
|
|
if self._is_non_byte_progress(filename):
|
|
return result
|
|
|
|
# When model is cached, also filter out generation-related progress
|
|
if tracker.filter_non_downloads:
|
|
if not self._is_download_progress(filename):
|
|
return result
|
|
|
|
# Update per-file tracking
|
|
tracker._file_sizes[filename] = total
|
|
tracker._file_downloaded[filename] = current
|
|
|
|
# Calculate totals across all files
|
|
tracker._total_size = sum(tracker._file_sizes.values())
|
|
tracker._total_downloaded = sum(tracker._file_downloaded.values())
|
|
|
|
# Only report progress once we have a meaningful total (at least 1MB)
|
|
# This avoids the "100% at 0MB" issue when small config
|
|
# files are counted before the real model files
|
|
MIN_TOTAL_BYTES = 1_000_000 # 1MB
|
|
if tracker._total_size < MIN_TOTAL_BYTES:
|
|
return result
|
|
|
|
# Call progress callback
|
|
if tracker.progress_callback:
|
|
tracker.progress_callback(tracker._total_downloaded, tracker._total_size, filename)
|
|
|
|
return result
|
|
|
|
def _is_non_byte_progress(self, filename: str) -> bool:
|
|
"""Check if this progress bar should be SKIPPED (returns True to skip).
|
|
|
|
We want to track byte-based progress bars. This method identifies
|
|
progress bars that count files/items instead of bytes, which would
|
|
cause crazy percentages if mixed with our byte counting.
|
|
|
|
Returns:
|
|
True = SKIP this bar (it's not byte-based)
|
|
False = TRACK this bar (it counts bytes)
|
|
"""
|
|
if not filename:
|
|
return False
|
|
|
|
filename_lower = filename.lower()
|
|
|
|
# Skip "Fetching X files" - it counts files (total=12), not bytes
|
|
# Don't skip "Downloading (incomplete total...)" - that IS byte-based
|
|
skip_patterns = [
|
|
"fetching", # "Fetching 12 files" has total=12 files, not bytes
|
|
]
|
|
return any(pattern in filename_lower for pattern in skip_patterns)
|
|
|
|
def _is_download_progress(self, filename: str) -> bool:
|
|
"""Check if this is a real file download progress bar vs internal processing."""
|
|
if not filename or filename == "unknown":
|
|
return False
|
|
|
|
# Real downloads have file extensions
|
|
download_extensions = [
|
|
".safetensors",
|
|
".bin",
|
|
".pt",
|
|
".pth", # Model weights
|
|
".json",
|
|
".txt",
|
|
".py", # Config files
|
|
".msgpack",
|
|
".h5", # Other formats
|
|
]
|
|
|
|
filename_lower = filename.lower()
|
|
has_extension = any(filename_lower.endswith(ext) for ext in download_extensions)
|
|
|
|
# Skip generation-related progress indicators
|
|
skip_patterns = ["segment", "processing", "generating", "loading"]
|
|
has_skip_pattern = any(pattern in filename_lower for pattern in skip_patterns)
|
|
|
|
return has_extension and not has_skip_pattern
|
|
|
|
def close(self):
|
|
with tracker._lock:
|
|
if id(self) in tracker._active_tqdms:
|
|
del tracker._active_tqdms[id(self)]
|
|
return super().close()
|
|
|
|
return TrackedTqdm
|
|
|
|
@contextmanager
|
|
def patch_download(self):
|
|
"""Context manager to patch tqdm for progress tracking."""
|
|
try:
|
|
import tqdm as tqdm_module
|
|
|
|
# Store original tqdm class
|
|
self._original_tqdm_class = tqdm_module.tqdm
|
|
|
|
# Reset totals
|
|
with self._lock:
|
|
self._total_downloaded = 0
|
|
self._total_size = 0
|
|
self._file_sizes = {}
|
|
self._file_downloaded = {}
|
|
self._current_filename = ""
|
|
self._active_tqdms = {}
|
|
|
|
# Create our tracked tqdm class
|
|
tracked_tqdm = self._create_tracked_tqdm_class()
|
|
|
|
# Patch tqdm.tqdm
|
|
tqdm_module.tqdm = tracked_tqdm
|
|
|
|
# Also patch tqdm.auto.tqdm if it exists (used by huggingface_hub)
|
|
self._original_tqdm_auto = None
|
|
if hasattr(tqdm_module, "auto") and hasattr(tqdm_module.auto, "tqdm"):
|
|
self._original_tqdm_auto = tqdm_module.auto.tqdm
|
|
tqdm_module.auto.tqdm = tracked_tqdm
|
|
|
|
# Patch in sys.modules to catch already-imported references
|
|
# huggingface_hub uses: from tqdm.auto import tqdm as base_tqdm
|
|
# So we need to patch both 'tqdm' and 'base_tqdm' attributes
|
|
self._patched_modules = {}
|
|
tqdm_attr_names = ["tqdm", "base_tqdm", "old_tqdm"] # Various names used
|
|
|
|
patched_count = 0
|
|
for module_name in list(sys.modules.keys()):
|
|
if "huggingface" in module_name or module_name.startswith("tqdm"):
|
|
try:
|
|
module = sys.modules[module_name]
|
|
for attr_name in tqdm_attr_names:
|
|
if hasattr(module, attr_name):
|
|
attr = getattr(module, attr_name)
|
|
# Only patch if it's a tqdm class (not already patched)
|
|
is_tqdm_class = (
|
|
attr is self._original_tqdm_class
|
|
or (self._original_tqdm_auto and attr is self._original_tqdm_auto)
|
|
or (
|
|
hasattr(attr, "__name__")
|
|
and attr.__name__ == "tqdm"
|
|
and hasattr(attr, "update")
|
|
) # tqdm classes have update method
|
|
)
|
|
if is_tqdm_class:
|
|
key = f"{module_name}.{attr_name}"
|
|
self._patched_modules[key] = (module, attr_name, attr)
|
|
setattr(module, attr_name, tracked_tqdm)
|
|
patched_count += 1
|
|
except (AttributeError, TypeError):
|
|
pass
|
|
|
|
# ALSO monkey-patch the update method on huggingface_hub's tqdm class
|
|
# This is needed because the class was already defined at import time
|
|
self._hf_tqdm_original_update = None
|
|
try:
|
|
from huggingface_hub.utils import tqdm as hf_tqdm_module
|
|
|
|
if hasattr(hf_tqdm_module, "tqdm"):
|
|
hf_tqdm_class = hf_tqdm_module.tqdm
|
|
self._hf_tqdm_original_update = hf_tqdm_class.update
|
|
|
|
# Create a wrapper that calls our tracking
|
|
tracker = self # Reference to HFProgressTracker instance
|
|
|
|
def patched_update(tqdm_self, n=1):
|
|
result = tracker._hf_tqdm_original_update(tqdm_self, n)
|
|
|
|
# Track this progress
|
|
with tracker._lock:
|
|
desc = getattr(tqdm_self, "desc", "") or ""
|
|
current = getattr(tqdm_self, "n", 0)
|
|
total = getattr(tqdm_self, "total", 0) or 0
|
|
|
|
# Skip non-byte progress bars
|
|
if "fetching" in desc.lower():
|
|
return result
|
|
|
|
# Skip until we have a meaningful total (at least 1MB)
|
|
# This avoids the "100% at 0MB" issue when small config
|
|
# files are counted before the real model files
|
|
MIN_TOTAL_BYTES = 1_000_000 # 1MB
|
|
if total >= MIN_TOTAL_BYTES:
|
|
tracker._total_downloaded = current
|
|
tracker._total_size = total
|
|
|
|
if tracker.progress_callback:
|
|
tracker.progress_callback(current, total, desc)
|
|
|
|
return result
|
|
|
|
hf_tqdm_class.update = patched_update
|
|
patched_count += 1
|
|
logger.debug("Monkey-patched huggingface_hub.utils.tqdm.tqdm.update")
|
|
except (ImportError, AttributeError) as e:
|
|
logger.warning("Could not monkey-patch hf_tqdm: %s", e)
|
|
|
|
logger.debug("Patched %d tqdm references", patched_count)
|
|
|
|
yield
|
|
|
|
except ImportError:
|
|
# If tqdm not available, just yield without patching
|
|
yield
|
|
finally:
|
|
# Restore original tqdm
|
|
if self._original_tqdm_class:
|
|
try:
|
|
import tqdm as tqdm_module
|
|
|
|
tqdm_module.tqdm = self._original_tqdm_class
|
|
|
|
if self._original_tqdm_auto:
|
|
tqdm_module.auto.tqdm = self._original_tqdm_auto
|
|
|
|
# Restore patched modules
|
|
for key, (module, attr_name, original) in self._patched_modules.items():
|
|
try:
|
|
if module and original:
|
|
setattr(module, attr_name, original)
|
|
except (AttributeError, TypeError):
|
|
pass
|
|
self._patched_modules = {}
|
|
|
|
# Restore hf_tqdm's original update method
|
|
if self._hf_tqdm_original_update:
|
|
try:
|
|
from huggingface_hub.utils import tqdm as hf_tqdm_module
|
|
|
|
if hasattr(hf_tqdm_module, "tqdm"):
|
|
hf_tqdm_module.tqdm.update = self._hf_tqdm_original_update
|
|
except (ImportError, AttributeError):
|
|
pass
|
|
self._hf_tqdm_original_update = None
|
|
|
|
except (ImportError, AttributeError):
|
|
pass
|
|
|
|
|
|
def create_hf_progress_callback(model_name: str, progress_manager):
|
|
"""Create a progress callback for HuggingFace downloads."""
|
|
|
|
def callback(downloaded: int, total: int, filename: str = ""):
|
|
"""Progress callback.
|
|
|
|
Note: We send updates even when total=0 (unknown) to provide feedback
|
|
during the "incomplete total" phase of huggingface_hub downloads.
|
|
The frontend handles total=0 gracefully.
|
|
"""
|
|
progress_manager.update_progress(
|
|
model_name=model_name,
|
|
current=downloaded,
|
|
total=total,
|
|
filename=filename or "",
|
|
status="downloading",
|
|
)
|
|
|
|
return callback
|