"""HuggingFace Hub download progress tracking."""

from typing import Optional, Callable
from contextlib import contextmanager
import logging
import threading
import sys

logger = logging.getLogger(__name__)

# Progress is only reported once the byte total reaches this value. Small
# config files are often counted before the real model files; reporting them
# would show "100% at 0MB" before the actual download begins.
_MIN_TOTAL_BYTES = 1_000_000  # 1MB

# Keyword arguments passed through to the tqdm constructor. huggingface_hub
# may pass custom kwargs that tqdm itself does not understand; those are dropped.
_TQDM_KWARGS = frozenset(
    {
        "iterable", "desc", "total", "leave", "file", "ncols", "mininterval",
        "maxinterval", "miniters", "ascii", "disable", "unit", "unit_scale",
        "dynamic_ncols", "smoothing", "bar_format", "initial", "position",
        "postfix", "unit_divisor", "write_bytes", "lock_args", "nrows",
        "colour", "color", "delay", "gui", "disable_default", "pos",
    }
)

# Attribute names under which already-imported modules may hold the tqdm
# class (huggingface_hub uses e.g. ``from tqdm.auto import tqdm as base_tqdm``).
_TQDM_ATTR_NAMES = ("tqdm", "base_tqdm", "old_tqdm")

# Suffixes that identify a real file-download progress bar.
_DOWNLOAD_EXTENSIONS = (
    ".safetensors", ".bin", ".pt", ".pth",  # model weights
    ".json", ".txt", ".py",  # config files
    ".msgpack", ".h5",  # other formats
)

# Substrings that mark generation/processing bars rather than downloads.
_GENERATION_SKIP_PATTERNS = ("segment", "processing", "generating", "loading")


def _is_non_byte_desc(desc: str) -> bool:
    """Return True for progress bars that count items instead of bytes.

    "Fetching 12 files" has ``total=12`` (files, not bytes); mixing it into
    byte totals would produce nonsensical percentages.
    """
    return bool(desc) and "fetching" in desc.lower()


def _filename_from_desc(desc) -> str:
    """Extract a filename from a tqdm description.

    HuggingFace Hub descriptions look like ``"model.safetensors: 0%|..."``;
    everything before the first colon is taken as the filename. Returns ""
    for an empty or non-string description.
    """
    if not desc or not isinstance(desc, str):
        return ""
    return desc.split(":", 1)[0].strip()


class HFProgressTracker:
    """Tracks HuggingFace Hub download progress by intercepting tqdm.

    Wrap a download in :meth:`patch_download`; while it is active, tqdm bars
    report progress to ``progress_callback(downloaded, total, filename)``.
    """

    def __init__(self, progress_callback: Optional[Callable] = None, filter_non_downloads: bool = False):
        """
        Args:
            progress_callback: Called as ``progress_callback(downloaded, total,
                filename)``. It is invoked without holding the tracker's lock.
            filter_non_downloads: When True, only bars whose name ends in a
                known model/config file extension (and contains no
                generation-related keyword) are counted.
        """
        self.progress_callback = progress_callback
        self.filter_non_downloads = filter_non_downloads  # Only filter if True
        self._lock = threading.Lock()
        self._total_downloaded = 0
        self._total_size = 0
        self._file_sizes = {}  # filename -> total bytes
        self._file_downloaded = {}  # filename -> downloaded bytes
        self._current_filename = ""
        self._active_tqdms = {}  # id(tqdm instance) -> {"filename": ...}
        # Patch bookkeeping; initialised here so restoration never hits a
        # missing attribute even if patching failed part-way.
        self._original_tqdm_class = None
        self._original_tqdm_auto = None
        self._patched_modules = {}  # "module.attr" -> (module, attr_name, original)
        self._hf_tqdm_class = None  # exact class whose update() we wrapped
        self._hf_tqdm_original_update = None
        self._hf_tqdm_had_own_update = False

    def _reset_totals(self) -> None:
        """Clear all per-session counters before a new download session."""
        with self._lock:
            self._total_downloaded = 0
            self._total_size = 0
            self._file_sizes = {}
            self._file_downloaded = {}
            self._current_filename = ""
            self._active_tqdms = {}

    def _create_tracked_tqdm_class(self):
        """Create a subclass of the original tqdm class that reports to this tracker."""
        tracker = self
        original_tqdm = self._original_tqdm_class

        class TrackedTqdm(original_tqdm):
            """A tqdm subclass that reports progress to our tracker."""

            def __init__(self, *args, **kwargs):
                # tqdm's positional order is (iterable, desc, total, ...): a
                # positional description is args[1]; args[0] is the iterable.
                desc = kwargs.get("desc")
                if not desc and len(args) > 1:
                    desc = args[1]
                filename = _filename_from_desc(desc)

                filtered_kwargs = {key: value for key, value in kwargs.items() if key in _TQDM_KWARGS}
                # Force-enable the bar: we don't need terminal rendering, but
                # we DO need self.n to be updated when update() is called.
                filtered_kwargs["disable"] = False
                try:
                    super().__init__(*args, **filtered_kwargs)
                except TypeError:
                    # Maybe this tqdm version accepts the extra kwargs.
                    kwargs["disable"] = False
                    super().__init__(*args, **kwargs)

                self._tracker_filename = filename or "unknown"
                with tracker._lock:
                    if filename:
                        tracker._current_filename = filename
                    tracker._active_tqdms[id(self)] = {"filename": self._tracker_filename}

            def update(self, n=1):
                result = super().update(n)
                report = None
                with tracker._lock:
                    entry = tracker._active_tqdms.get(id(self))
                    if entry is not None:
                        filename = entry["filename"]
                        current = getattr(self, "n", 0)
                        total = getattr(self, "total", 0)
                        if total and total > 0 and self._should_track(filename):
                            tracker._file_sizes[filename] = total
                            tracker._file_downloaded[filename] = current
                            tracker._total_size = sum(tracker._file_sizes.values())
                            tracker._total_downloaded = sum(tracker._file_downloaded.values())
                            if tracker._total_size >= _MIN_TOTAL_BYTES:
                                report = (tracker._total_downloaded, tracker._total_size, filename)
                # Invoke the callback after releasing the non-reentrant lock so
                # a slow callback cannot stall other bars, and a callback that
                # creates another tracked bar cannot deadlock.
                callback = tracker.progress_callback
                if report is not None and callback:
                    callback(*report)
                return result

            def _should_track(self, filename: str) -> bool:
                """Return True if this bar's progress should be counted."""
                # Always skip item-counting bars (e.g. "Fetching 12 files").
                if self._is_non_byte_progress(filename):
                    return False
                # When requested, also skip generation/processing bars.
                if tracker.filter_non_downloads and not self._is_download_progress(filename):
                    return False
                return True

            def _is_non_byte_progress(self, filename: str) -> bool:
                """Return True to SKIP a bar that counts files/items rather than bytes."""
                return _is_non_byte_desc(filename)

            def _is_download_progress(self, filename: str) -> bool:
                """Return True for a real file download bar vs internal processing."""
                if not filename or filename == "unknown":
                    return False
                filename_lower = filename.lower()
                has_extension = filename_lower.endswith(_DOWNLOAD_EXTENSIONS)
                has_skip_pattern = any(pattern in filename_lower for pattern in _GENERATION_SKIP_PATTERNS)
                return has_extension and not has_skip_pattern

            def close(self):
                with tracker._lock:
                    tracker._active_tqdms.pop(id(self), None)
                return super().close()

        return TrackedTqdm

    @contextmanager
    def patch_download(self):
        """Context manager that routes tqdm progress to this tracker.

        Patches ``tqdm.tqdm``, ``tqdm.auto.tqdm`` (if ``tqdm.auto`` is already
        imported), tqdm class references already bound in tqdm/huggingface
        modules, and ``update`` on huggingface_hub's tqdm class (if
        importable). Every patch is undone on exit, including when the body
        raises. If tqdm is not installed, the body runs with nothing patched.
        """
        try:
            import tqdm as tqdm_module
        except ImportError:
            # Kept outside the patching try-block: an ImportError raised by the
            # with-body must propagate, not be caught here and yield twice.
            yield
            return

        self._original_tqdm_class = tqdm_module.tqdm
        self._original_tqdm_auto = None
        self._patched_modules = {}
        self._hf_tqdm_class = None
        self._hf_tqdm_original_update = None
        self._hf_tqdm_had_own_update = False
        try:
            self._reset_totals()
            tracked_tqdm = self._create_tracked_tqdm_class()

            tqdm_module.tqdm = tracked_tqdm
            # tqdm.auto.tqdm is what huggingface_hub uses.
            auto_module = getattr(tqdm_module, "auto", None)
            if auto_module is not None and hasattr(auto_module, "tqdm"):
                self._original_tqdm_auto = auto_module.tqdm
                auto_module.tqdm = tracked_tqdm

            patched_count = self._patch_module_references(tracked_tqdm)
            patched_count += self._patch_hf_tqdm_update(tracked_tqdm)
            logger.debug("Patched %d tqdm references", patched_count)
            yield
        finally:
            self._restore_patches(tqdm_module)

    def _patch_module_references(self, tracked_tqdm) -> int:
        """Replace tqdm class references already bound in imported modules.

        Modules that ran ``from tqdm.auto import tqdm as base_tqdm`` hold their
        own reference, so patching ``tqdm.tqdm`` alone would miss them. Only
        modules whose name contains "huggingface" or starts with "tqdm" are
        touched. Returns the number of references replaced.
        """
        count = 0
        for module_name, module in list(sys.modules.items()):
            if "huggingface" not in module_name and not module_name.startswith("tqdm"):
                continue
            for attr_name in _TQDM_ATTR_NAMES:
                try:
                    if not hasattr(module, attr_name):
                        continue
                    attr = getattr(module, attr_name)
                    if not self._looks_like_tqdm_class(attr):
                        continue
                    setattr(module, attr_name, tracked_tqdm)
                except (AttributeError, TypeError):
                    # Best effort: some modules reject attribute access/assignment.
                    continue
                self._patched_modules[f"{module_name}.{attr_name}"] = (module, attr_name, attr)
                count += 1
        return count

    def _looks_like_tqdm_class(self, attr) -> bool:
        """Return True if ``attr`` is an unpatched tqdm class that should be replaced."""
        if attr is self._original_tqdm_class:
            return True
        if self._original_tqdm_auto is not None and attr is self._original_tqdm_auto:
            return True
        # tqdm classes are named "tqdm" and have an update method.
        return getattr(attr, "__name__", None) == "tqdm" and hasattr(attr, "update")

    def _patch_hf_tqdm_update(self, tracked_tqdm) -> int:
        """Wrap ``update`` on huggingface_hub's tqdm class; return 1 if patched, else 0.

        Needed because that class was defined at import time, before our
        patches. The exact class object is remembered so restoration undoes
        the change on that same class.
        """
        try:
            from huggingface_hub.utils import tqdm as hf_tqdm_module

            hf_tqdm_class = getattr(hf_tqdm_module, "tqdm", None)
            if hf_tqdm_class is None:
                return 0
            original_update = hf_tqdm_class.update
        except (ImportError, AttributeError) as e:
            logger.warning("Could not monkey-patch hf_tqdm: %s", e)
            return 0

        # Already swapped for TrackedTqdm, which reports progress itself;
        # wrapping it again would double-report with per-bar totals.
        if isinstance(hf_tqdm_class, type) and issubclass(hf_tqdm_class, tracked_tqdm):
            return 0

        tracker = self

        def patched_update(tqdm_self, n=1):
            # Call the update captured at patch time (not looked up later), so
            # a stale reference still works after restoration.
            result = original_update(tqdm_self, n)
            report = None
            with tracker._lock:
                desc = getattr(tqdm_self, "desc", "") or ""
                current = getattr(tqdm_self, "n", 0)
                total = getattr(tqdm_self, "total", 0) or 0
                # Skip item-counting bars and totals below the reporting threshold.
                if not _is_non_byte_desc(desc) and total >= _MIN_TOTAL_BYTES:
                    tracker._total_downloaded = current
                    tracker._total_size = total
                    report = (current, total, desc)
            # Callback runs outside the lock (non-reentrant; see TrackedTqdm).
            callback = tracker.progress_callback
            if report is not None and callback:
                callback(*report)
            return result

        try:
            had_own_update = "update" in getattr(hf_tqdm_class, "__dict__", {})
            hf_tqdm_class.update = patched_update
        except (AttributeError, TypeError) as e:
            logger.warning("Could not monkey-patch hf_tqdm: %s", e)
            return 0

        self._hf_tqdm_class = hf_tqdm_class
        self._hf_tqdm_original_update = original_update
        self._hf_tqdm_had_own_update = had_own_update
        logger.debug("Monkey-patched huggingface_hub.utils.tqdm.tqdm.update")
        return 1

    def _restore_patches(self, tqdm_module) -> None:
        """Undo every patch applied by patch_download; each step is independent."""
        if self._original_tqdm_class is not None:
            tqdm_module.tqdm = self._original_tqdm_class
        if self._original_tqdm_auto is not None:
            auto_module = getattr(tqdm_module, "auto", None)
            if auto_module is not None:
                auto_module.tqdm = self._original_tqdm_auto

        for key, (module, attr_name, original) in self._patched_modules.items():
            try:
                setattr(module, attr_name, original)
            except (AttributeError, TypeError) as e:
                logger.debug("Could not restore %s: %s", key, e)
        self._patched_modules = {}

        hf_tqdm_class = self._hf_tqdm_class
        if hf_tqdm_class is not None:
            try:
                if self._hf_tqdm_had_own_update:
                    hf_tqdm_class.update = self._hf_tqdm_original_update
                else:
                    # update was inherited: drop our override so lookup falls
                    # through to the base class again.
                    del hf_tqdm_class.update
            except (AttributeError, TypeError) as e:
                logger.debug("Could not restore hf_tqdm update: %s", e)
        self._hf_tqdm_class = None
        self._hf_tqdm_original_update = None
        self._hf_tqdm_had_own_update = False


def create_hf_progress_callback(model_name: str, progress_manager):
    """Create a progress callback for HuggingFace downloads.

    Args:
        model_name: Forwarded as ``model_name`` on every update.
        progress_manager: Object exposing ``update_progress(**kwargs)``.

    Returns:
        A ``callback(downloaded, total, filename="")`` that forwards each call
        to ``progress_manager.update_progress`` with ``status="downloading"``.
    """

    def callback(downloaded: int, total: int, filename: str = ""):
        """Forward one progress update.

        Values are passed through unchanged, including ``total=0`` (unknown
        size); the progress manager is presumably expected to handle that.
        """
        progress_manager.update_progress(
            model_name=model_name,
            current=downloaded,
            total=total,
            filename=filename or "",
            status="downloading",
        )

    return callback