diff --git a/modules/dat_model.py b/modules/dat_model.py index 495d5f493..298d160d1 100644 --- a/modules/dat_model.py +++ b/modules/dat_model.py @@ -49,7 +49,18 @@ class UpscalerDAT(Upscaler): scaler.local_data_path = modelloader.load_file_from_url( scaler.data_path, model_dir=self.model_download_path, + hash_prefix=scaler.sha256, ) + + if os.path.getsize(scaler.local_data_path) < 200: + # Re-download if the file is too small, probably an LFS pointer + scaler.local_data_path = modelloader.load_file_from_url( + scaler.data_path, + model_dir=self.model_download_path, + hash_prefix=scaler.sha256, + re_download=True, + ) + if not os.path.exists(scaler.local_data_path): raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}") return scaler @@ -60,20 +71,23 @@ def get_dat_models(scaler): return [ UpscalerData( name="DAT x2", - path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth", + path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x2.pth", scale=2, upscaler=scaler, + sha256='7760aa96e4ee77e29d4f89c3a4486200042e019461fdb8aa286f49aa00b89b51', ), UpscalerData( name="DAT x3", - path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth", + path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x3.pth", scale=3, upscaler=scaler, + sha256='581973e02c06f90d4eb90acf743ec9604f56f3c2c6f9e1e2c2b38ded1f80d197', ), UpscalerData( name="DAT x4", - path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth", + path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x4.pth", scale=4, upscaler=scaler, + sha256='391a6ce69899dff5ea3214557e9d585608254579217169faf3d4c353caff049e', ), ] diff --git a/modules/modelloader.py b/modules/modelloader.py index 36e7415af..f5a2ff79c 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -10,6 +10,7 @@ import torch from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone +from modules.util import load_file_from_url # noqa, backwards compatibility if TYPE_CHECKING: import spandrel @@ -17,30 +18,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def load_file_from_url( - url: str, - *, - model_dir: str, - progress: bool = True, - file_name: str | None = None, - hash_prefix: str | None = None, -) -> str: - """Download a file from `url` into `model_dir`, using the file present if possible. - - Returns the path to the downloaded file. - """ - os.makedirs(model_dir, exist_ok=True) - if not file_name: - parts = urlparse(url) - file_name = os.path.basename(parts.path) - cached_file = os.path.abspath(os.path.join(model_dir, file_name)) - if not os.path.exists(cached_file): - print(f'Downloading: "{url}" to {cached_file}\n') - from torch.hub import download_url_to_file - download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix) - return cached_file - - def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. diff --git a/modules/upscaler.py b/modules/upscaler.py index 507881fed..12ab3547c 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -93,13 +93,14 @@ class UpscalerData: scaler: Upscaler = None model: None - def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): + def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None, sha256: str = None): self.name = name self.data_path = path self.local_data_path = path self.scaler = upscaler self.scale = scale self.model = model + self.sha256 = sha256 def __repr__(self): return f"" diff --git a/modules/util.py b/modules/util.py index 7911b0db7..baeba2fa2 100644 --- a/modules/util.py +++ b/modules/util.py @@ -211,3 +211,80 @@ Requested path was: {path} subprocess.Popen(["explorer.exe", subprocess.check_output(["wslpath", "-w", path])]) else: subprocess.Popen(["xdg-open", path]) + + +def load_file_from_url( + url: str, + *, + model_dir: str, + progress: bool = True, + file_name: str | None = None, + hash_prefix: str | None = None, + re_download: bool = False, +) -> str: + """Download a file from `url` into `model_dir`, using the file present if possible. + Returns the path to the downloaded file. + + file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url. + file is downloaded to {file_name}.tmp then moved to the final location after download is complete. + hash_prefix: sha256 hex string, if provided, the hash of the downloaded file will be checked against this prefix. + if the hash does not match, the temporary file is deleted and a ValueError is raised. + re_download: forcibly re-download the file even if it already exists. + """ + from urllib.parse import urlparse + import requests + try: + from tqdm import tqdm + except ImportError: + class tqdm: + def __init__(self, *args, **kwargs): + pass + + def update(self, n=1, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + if not file_name: + parts = urlparse(url) + file_name = os.path.basename(parts.path) + + cached_file = os.path.abspath(os.path.join(model_dir, file_name)) + + if re_download or not os.path.exists(cached_file): + os.makedirs(model_dir, exist_ok=True) + temp_file = os.path.join(model_dir, f"{file_name}.tmp") + print(f'\nDownloading: "{url}" to {cached_file}') + response = requests.get(url, stream=True) + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, disable=not progress) as progress_bar: + with open(temp_file, 'wb') as file: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + file.write(chunk) + progress_bar.update(len(chunk)) + + if hash_prefix and not compare_sha256(temp_file, hash_prefix): + print(f"Hash mismatch for {temp_file}. Deleting the temporary file.") + os.remove(temp_file) + raise ValueError(f"File hash does not match the expected hash prefix {hash_prefix}!") + + os.rename(temp_file, cached_file) + return cached_file + + +def compare_sha256(file_path: str, hash_prefix: str) -> bool: + """Check if the SHA256 hash of the file matches the given prefix.""" + import hashlib + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())