Initial commit
This commit is contained in:
270
backend/utils/hf_offline_patch.py
Normal file
270
backend/utils/hf_offline_patch.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""Monkey-patch huggingface_hub to force offline mode with cached models.
|
||||
|
||||
Prevents mlx_audio / transformers from making network requests when models
|
||||
are already downloaded. Must be imported BEFORE mlx_audio.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# huggingface_hub reads ``HF_HUB_OFFLINE`` once at import time into
|
||||
# ``huggingface_hub.constants.HF_HUB_OFFLINE``; transformers mirrors that into
|
||||
# ``transformers.utils.hub._is_offline_mode`` at *its* import time. Toggling
|
||||
# ``os.environ`` after either module is imported does not flip those cached
|
||||
# bools, and the hot paths (``_http._default_backend_factory``,
|
||||
# ``transformers.utils.hub.is_offline_mode``) read the bools — not the env.
|
||||
# We mutate the cached constants directly, guarded by a refcount so
|
||||
# concurrent inference threads share a single offline window safely.
|
||||
|
||||
_offline_lock = threading.RLock()
|
||||
_offline_refcount = 0
|
||||
_saved_env: Optional[str] = None
|
||||
_saved_hf_const: Optional[bool] = None
|
||||
_saved_transformers_const: Optional[bool] = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def force_offline_if_cached(is_cached: bool, model_label: str = ""):
|
||||
"""Force offline mode for the duration of a cached-model operation.
|
||||
|
||||
Flips ``HF_HUB_OFFLINE`` in the process env **and** in the cached bools
|
||||
inside ``huggingface_hub.constants`` / ``transformers.utils.hub`` so HTTP
|
||||
adapters and offline-mode checks actually see the change. Uses a refcount
|
||||
so multiple concurrent inference threads share a single offline window
|
||||
and the last one to exit restores state.
|
||||
|
||||
If *is_cached* is ``False`` the block runs normally (network allowed).
|
||||
|
||||
Args:
|
||||
is_cached: Whether the model weights are already on disk.
|
||||
model_label: Human-readable name used in log messages.
|
||||
"""
|
||||
if not is_cached:
|
||||
yield
|
||||
return
|
||||
|
||||
global _offline_refcount, _saved_env, _saved_hf_const, _saved_transformers_const
|
||||
|
||||
with _offline_lock:
|
||||
if _offline_refcount == 0:
|
||||
# Snapshot prior state, apply new state, roll back on *any*
|
||||
# failure. Catching only ImportError here would let a partially
|
||||
# broken install (RuntimeError, AttributeError from a half-init
|
||||
# module, etc.) leave the cached HF constants mutated without
|
||||
# bumping the refcount — a persistent offline leak that outlives
|
||||
# the process and is miserable to debug.
|
||||
prev_env = os.environ.get("HF_HUB_OFFLINE")
|
||||
prev_hf: Optional[bool] = None
|
||||
prev_tf: Optional[bool] = None
|
||||
try:
|
||||
try:
|
||||
import huggingface_hub.constants as hf_const
|
||||
|
||||
prev_hf = hf_const.HF_HUB_OFFLINE
|
||||
hf_const.HF_HUB_OFFLINE = True
|
||||
except ImportError:
|
||||
prev_hf = None
|
||||
|
||||
try:
|
||||
import transformers.utils.hub as tf_hub
|
||||
|
||||
prev_tf = tf_hub._is_offline_mode
|
||||
tf_hub._is_offline_mode = True
|
||||
except ImportError:
|
||||
prev_tf = None
|
||||
|
||||
os.environ["HF_HUB_OFFLINE"] = "1"
|
||||
except BaseException:
|
||||
# Roll back whatever we already changed, then re-raise so
|
||||
# the caller sees the real failure.
|
||||
if prev_hf is not None:
|
||||
try:
|
||||
import huggingface_hub.constants as hf_const
|
||||
|
||||
hf_const.HF_HUB_OFFLINE = prev_hf
|
||||
except ImportError:
|
||||
pass
|
||||
if prev_tf is not None:
|
||||
try:
|
||||
import transformers.utils.hub as tf_hub
|
||||
|
||||
tf_hub._is_offline_mode = prev_tf
|
||||
except ImportError:
|
||||
pass
|
||||
if prev_env is not None:
|
||||
os.environ["HF_HUB_OFFLINE"] = prev_env
|
||||
else:
|
||||
os.environ.pop("HF_HUB_OFFLINE", None)
|
||||
raise
|
||||
|
||||
_saved_env = prev_env
|
||||
_saved_hf_const = prev_hf
|
||||
_saved_transformers_const = prev_tf
|
||||
logger.info(
|
||||
"[offline-guard] %s is cached — forcing offline mode",
|
||||
model_label or "model",
|
||||
)
|
||||
_offline_refcount += 1
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
with _offline_lock:
|
||||
_offline_refcount -= 1
|
||||
if _offline_refcount == 0:
|
||||
if _saved_env is not None:
|
||||
os.environ["HF_HUB_OFFLINE"] = _saved_env
|
||||
else:
|
||||
os.environ.pop("HF_HUB_OFFLINE", None)
|
||||
if _saved_hf_const is not None:
|
||||
try:
|
||||
import huggingface_hub.constants as hf_const
|
||||
|
||||
hf_const.HF_HUB_OFFLINE = _saved_hf_const
|
||||
except ImportError:
|
||||
pass
|
||||
if _saved_transformers_const is not None:
|
||||
try:
|
||||
import transformers.utils.hub as tf_hub
|
||||
|
||||
tf_hub._is_offline_mode = _saved_transformers_const
|
||||
except ImportError:
|
||||
pass
|
||||
_saved_env = None
|
||||
_saved_hf_const = None
|
||||
_saved_transformers_const = None
|
||||
|
||||
|
||||
_mistral_regex_patched = False
|
||||
|
||||
|
||||
def patch_transformers_mistral_regex():
|
||||
"""Make transformers' tokenizer load robust to HuggingFace metadata failures.
|
||||
|
||||
transformers 4.57.x added ``PreTrainedTokenizerBase._patch_mistral_regex``
|
||||
which unconditionally calls ``huggingface_hub.model_info(repo_id)`` during
|
||||
every non-local tokenizer load to check whether the model is a Mistral
|
||||
variant. That call raises on ``HF_HUB_OFFLINE=1`` and on plain network
|
||||
failures, killing unrelated loads (Qwen TTS, TADA, etc.).
|
||||
|
||||
Voicebox never loads Mistral models, so the rewrite the function would
|
||||
apply is a no-op for us anyway. Wrap the method so any exception from the
|
||||
metadata lookup returns the tokenizer unchanged — matching the success-path
|
||||
behavior for non-Mistral repos (transformers 4.57.3,
|
||||
``tokenization_utils_base.py:2503``).
|
||||
"""
|
||||
global _mistral_regex_patched
|
||||
if _mistral_regex_patched:
|
||||
return
|
||||
|
||||
try:
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
except ImportError:
|
||||
logger.debug("transformers not available, skipping mistral-regex patch")
|
||||
return
|
||||
|
||||
original = getattr(PreTrainedTokenizerBase, "_patch_mistral_regex", None)
|
||||
if original is None:
|
||||
logger.debug(
|
||||
"transformers has no _patch_mistral_regex attribute, skipping patch",
|
||||
)
|
||||
return
|
||||
|
||||
def safe_patch_mistral_regex(cls, tokenizer, pretrained_model_name_or_path, *args, **kwargs):
|
||||
try:
|
||||
return original(tokenizer, pretrained_model_name_or_path, *args, **kwargs)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"[mistral-regex-patch] suppressed %s for %r, returning tokenizer as-is",
|
||||
type(exc).__name__,
|
||||
pretrained_model_name_or_path,
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
PreTrainedTokenizerBase._patch_mistral_regex = classmethod(safe_patch_mistral_regex)
|
||||
_mistral_regex_patched = True
|
||||
logger.debug("installed _patch_mistral_regex wrapper")
|
||||
|
||||
|
||||
def patch_huggingface_hub_offline():
|
||||
"""Monkey-patch huggingface_hub to force offline mode."""
|
||||
try:
|
||||
import huggingface_hub # noqa: F401 -- need the package loaded
|
||||
from huggingface_hub import constants as hf_constants
|
||||
from huggingface_hub.file_download import _try_to_load_from_cache
|
||||
|
||||
original_try_load = _try_to_load_from_cache
|
||||
|
||||
def _patched_try_to_load_from_cache(
|
||||
repo_id: str,
|
||||
filename: str,
|
||||
cache_dir: Union[str, Path, None] = None,
|
||||
revision: Optional[str] = None,
|
||||
repo_type: Optional[str] = None,
|
||||
):
|
||||
result = original_try_load(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
cache_path = Path(hf_constants.HF_HUB_CACHE) / f"models--{repo_id.replace('/', '--')}"
|
||||
logger.debug("file not cached: %s/%s (expected at %s)", repo_id, filename, cache_path)
|
||||
else:
|
||||
logger.debug("cache hit: %s/%s", repo_id, filename)
|
||||
|
||||
return result
|
||||
|
||||
import huggingface_hub.file_download as fd
|
||||
|
||||
fd._try_to_load_from_cache = _patched_try_to_load_from_cache
|
||||
logger.debug("huggingface_hub patched for offline mode")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("huggingface_hub not available, skipping offline patch")
|
||||
except Exception:
|
||||
logger.exception("failed to patch huggingface_hub for offline mode")
|
||||
|
||||
|
||||
def ensure_original_qwen_config_cached():
|
||||
"""Symlink the original Qwen repo cache to the MLX community version.
|
||||
|
||||
mlx_audio may try to fetch config from the original Qwen repo. If only
|
||||
the MLX community variant is cached, create a symlink so the cache lookup
|
||||
succeeds without a network request.
|
||||
"""
|
||||
try:
|
||||
from huggingface_hub import constants as hf_constants
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
original_repo = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
|
||||
mlx_repo = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
|
||||
|
||||
cache_dir = Path(hf_constants.HF_HUB_CACHE)
|
||||
original_path = cache_dir / f"models--{original_repo.replace('/', '--')}"
|
||||
mlx_path = cache_dir / f"models--{mlx_repo.replace('/', '--')}"
|
||||
|
||||
if not original_path.exists() and mlx_path.exists():
|
||||
try:
|
||||
original_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
original_path.symlink_to(mlx_path, target_is_directory=True)
|
||||
logger.info("created cache symlink: %s -> %s", original_repo, mlx_repo)
|
||||
except Exception:
|
||||
logger.warning("could not create cache symlink for %s", original_repo, exc_info=True)
|
||||
|
||||
|
||||
if os.environ.get("VOICEBOX_OFFLINE_PATCH", "1") != "0":
|
||||
patch_huggingface_hub_offline()
|
||||
patch_transformers_mistral_regex()
|
||||
ensure_original_qwen_config_cached()
|
||||
Reference in New Issue
Block a user