Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git, synced 2025-04-06 04:39:01 +08:00

Fix DAT models download (#16302)

commit 984b952eb3 (parent 5865da28d1)
@@ -49,7 +49,18 @@ class UpscalerDAT(Upscaler):
                     scaler.local_data_path = modelloader.load_file_from_url(
                         scaler.data_path,
                         model_dir=self.model_download_path,
+                        hash_prefix=scaler.sha256,
                     )
+
+                    if os.path.getsize(scaler.local_data_path) < 200:
+                        # Re-download if the file is too small, probably an LFS pointer
+                        scaler.local_data_path = modelloader.load_file_from_url(
+                            scaler.data_path,
+                            model_dir=self.model_download_path,
+                            hash_prefix=scaler.sha256,
+                            re_download=True,
+                        )
+
                 if not os.path.exists(scaler.local_data_path):
                     raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
                 return scaler
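A note on the size check: when a Git LFS object is requested from an endpoint that does not resolve LFS, the server hands back the small text pointer file instead of the weights. Such pointers run to roughly 130 bytes, which is what the `< 200` threshold above catches. A minimal sketch of that pointer format and an explicit check (the `is_lfs_pointer` helper is illustrative, not part of this commit):

    import os

    # A Git LFS pointer is a short text file, e.g.:
    #   version https://git-lfs.github.com/spec/v1
    #   oid sha256:<64 hex digits>
    #   size <file size in bytes>

    def is_lfs_pointer(path: str) -> bool:
        """Heuristic: a tiny file that starts with the LFS spec header."""
        if os.path.getsize(path) >= 200:
            return False
        with open(path, "rb") as f:
            return f.read(7) == b"version"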
@@ -60,20 +71,23 @@ def get_dat_models(scaler):
     return [
         UpscalerData(
             name="DAT x2",
-            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
+            path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x2.pth",
             scale=2,
             upscaler=scaler,
+            sha256='7760aa96e4ee77e29d4f89c3a4486200042e019461fdb8aa286f49aa00b89b51',
         ),
         UpscalerData(
             name="DAT x3",
-            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
+            path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x3.pth",
             scale=3,
             upscaler=scaler,
+            sha256='581973e02c06f90d4eb90acf743ec9604f56f3c2c6f9e1e2c2b38ded1f80d197',
         ),
         UpscalerData(
             name="DAT x4",
-            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
+            path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x4.pth",
             scale=4,
             upscaler=scaler,
+            sha256='391a6ce69899dff5ea3214557e9d585608254579217169faf3d4c353caff049e',
         ),
     ]
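The `sha256` values added above pin the expected contents of each checkpoint. To confirm a downloaded file matches its pin, a few lines of standard-library Python suffice (the file path below is a placeholder for wherever the model landed):

    import hashlib

    def sha256_of(path: str, bufsize: int = 1 << 20) -> str:
        """Stream the file so large .pth checkpoints aren't loaded into memory."""
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(bufsize), b""):
                h.update(chunk)
        return h.hexdigest()

    # For an intact DAT x2 download this should print the pinned digest:
    # 7760aa96e4ee77e29d4f89c3a4486200042e019461fdb8aa286f49aa00b89b51
    print(sha256_of("DAT_x2.pth"))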
@@ -10,6 +10,7 @@ import torch

 from modules import shared
 from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
+from modules.util import load_file_from_url  # noqa, backwards compatibility

 if TYPE_CHECKING:
     import spandrel
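The `# noqa, backwards compatibility` re-export exists because `load_file_from_url` moves out of this module in this commit (see the hunks below); importing it here keeps older call sites working. A sketch of both import paths, assuming an external caller such as an extension:

    # Old import path still resolves, thanks to the re-export:
    from modules.modelloader import load_file_from_url

    # New code can target the function's new home directly:
    from modules.util import load_file_from_url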
@@ -17,30 +18,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


-def load_file_from_url(
-    url: str,
-    *,
-    model_dir: str,
-    progress: bool = True,
-    file_name: str | None = None,
-    hash_prefix: str | None = None,
-) -> str:
-    """Download a file from `url` into `model_dir`, using the file present if possible.
-
-    Returns the path to the downloaded file.
-    """
-    os.makedirs(model_dir, exist_ok=True)
-    if not file_name:
-        parts = urlparse(url)
-        file_name = os.path.basename(parts.path)
-    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
-    if not os.path.exists(cached_file):
-        print(f'Downloading: "{url}" to {cached_file}\n')
-        from torch.hub import download_url_to_file
-        download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix)
-    return cached_file
-
-
 def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list:
     """
     A one-and done loader to try finding the desired models in specified directories.
@@ -93,13 +93,14 @@ class UpscalerData:
     scaler: Upscaler = None
     model: None

-    def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
+    def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None, sha256: str = None):
         self.name = name
         self.data_path = path
         self.local_data_path = path
         self.scaler = upscaler
         self.scale = scale
         self.model = model
+        self.sha256 = sha256

     def __repr__(self):
         return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"
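With the widened constructor, a caller attaches the pinned digest when registering a model, and `UpscalerDAT.load_model` in the first hunk forwards it as `hash_prefix`. A minimal sketch using values from the hunks above:

    from modules.upscaler import UpscalerData

    dat_x2 = UpscalerData(
        name="DAT x2",
        path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x2.pth",
        scale=2,
        sha256='7760aa96e4ee77e29d4f89c3a4486200042e019461fdb8aa286f49aa00b89b51',
    )
    assert dat_x2.sha256 is not None  # the digest now travels with the model entry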
@@ -211,3 +211,80 @@ Requested path was: {path}
         subprocess.Popen(["explorer.exe", subprocess.check_output(["wslpath", "-w", path])])
     else:
         subprocess.Popen(["xdg-open", path])
+
+
+def load_file_from_url(
+    url: str,
+    *,
+    model_dir: str,
+    progress: bool = True,
+    file_name: str | None = None,
+    hash_prefix: str | None = None,
+    re_download: bool = False,
+) -> str:
+    """Download a file from `url` into `model_dir`, using the file present if possible.
+    Returns the path to the downloaded file.
+
+    file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url.
+        file is downloaded to {file_name}.tmp then moved to the final location after download is complete.
+    hash_prefix: sha256 hex string, if provided, the hash of the downloaded file will be checked against this prefix.
+        if the hash does not match, the temporary file is deleted and a ValueError is raised.
+    re_download: forcibly re-download the file even if it already exists.
+    """
+    from urllib.parse import urlparse
+    import requests
+    try:
+        from tqdm import tqdm
+    except ImportError:
+        class tqdm:
+            def __init__(self, *args, **kwargs):
+                pass
+
+            def update(self, n=1, *args, **kwargs):
+                pass
+
+            def __enter__(self):
+                return self
+
+            def __exit__(self, exc_type, exc_val, exc_tb):
+                pass
+
+    if not file_name:
+        parts = urlparse(url)
+        file_name = os.path.basename(parts.path)
+
+    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+
+    if re_download or not os.path.exists(cached_file):
+        os.makedirs(model_dir, exist_ok=True)
+        temp_file = os.path.join(model_dir, f"{file_name}.tmp")
+        print(f'\nDownloading: "{url}" to {cached_file}')
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+        total_size = int(response.headers.get('content-length', 0))
+        with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, disable=not progress) as progress_bar:
+            with open(temp_file, 'wb') as file:
+                for chunk in response.iter_content(chunk_size=1024):
+                    if chunk:
+                        file.write(chunk)
+                        progress_bar.update(len(chunk))
+
+        if hash_prefix and not compare_sha256(temp_file, hash_prefix):
+            print(f"Hash mismatch for {temp_file}. Deleting the temporary file.")
+            os.remove(temp_file)
+            raise ValueError(f"File hash does not match the expected hash prefix {hash_prefix}!")
+
+        os.rename(temp_file, cached_file)
+    return cached_file
+
+
+def compare_sha256(file_path: str, hash_prefix: str) -> bool:
+    """Check if the SHA256 hash of the file matches the given prefix."""
+    import hashlib
+    hash_sha256 = hashlib.sha256()
+    blksize = 1024 * 1024
+
+    with open(file_path, "rb") as f:
+        for chunk in iter(lambda: f.read(blksize), b""):
+            hash_sha256.update(chunk)
+    return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())
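Putting the new pieces together: a download that shows a progress bar, verifies the pinned digest, and can be forced to re-fetch a suspect copy might look like the following (the model_dir value is a placeholder; this mirrors how `UpscalerDAT.load_model` calls it in the first hunk):

    from modules.util import load_file_from_url

    url = "https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x2.pth"
    sha = '7760aa96e4ee77e29d4f89c3a4486200042e019461fdb8aa286f49aa00b89b51'

    # Normal path: reuse the cached file if present, verify the hash on fresh downloads.
    path = load_file_from_url(url, model_dir="models/DAT", hash_prefix=sha)

    # Suspect copy (e.g. an LFS pointer instead of weights): force a clean re-fetch.
    path = load_file_from_url(url, model_dir="models/DAT", hash_prefix=sha, re_download=True)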
|
Loading…
x
Reference in New Issue
Block a user