mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-04-04 19:49:01 +08:00
210 lines
5.9 KiB
Python
210 lines
5.9 KiB
Python
"""
Supports saving and restoring the webui and its extensions to and from a known-working set of git commits.
"""
|
|
|
|
import os
|
|
import tqdm
|
|
import diskcache
|
|
|
|
from datetime import datetime
|
|
import git
|
|
|
|
from modules import shared, extensions, errors
|
|
from modules.paths_internal import script_path, config_states_dir
|
|
|
|
# Index of saved states keyed by display name ("<name>: <timestamp>"); rebuilt by list_config_states().
all_config_states = {}

# Persistent store of saved config states, keyed by the name passed to save_config().
# NOTE(review): with a size limit set, diskcache may cull older entries once the limit is
# exceeded, silently dropping saved states — confirm this rotation is intended.
config_states_cache = diskcache.Cache(config_states_dir, size_limit=1024 * 1024)  # 1 MiB
|
|
|
|
|
|
def list_config_states():
    """Rebuild and return the index of saved config states.

    Every entry in ``config_states_cache`` is filed in the module-level
    ``all_config_states`` dict under a display name of the form
    ``"<name>: <YYYY-mm-dd HH:MM:SS>"`` (``name`` defaults to ``"Config"``),
    ordered newest first by its ``created_at`` timestamp.

    Entries that disappear between listing the keys and reading them, or that
    have no ``created_at`` timestamp, are skipped instead of aborting the listing.

    Returns:
        The ``all_config_states`` dict (cleared and refilled in place).
    """
    # Mutated in place only, so no `global` declaration is needed.
    all_config_states.clear()

    config_states = []
    for key in list(config_states_cache):
        # The entry may have been removed or evicted since the key list was taken;
        # .get() returns None then instead of raising KeyError.
        cs = config_states_cache.get(key)
        if cs is None:
            continue
        if not isinstance(cs, dict) or "created_at" not in cs:
            print(f'[ERROR]: Config state {key!r} has no "created_at" timestamp, skipping')
            continue
        config_states.append(cs)

    # Cache iteration order reflects insertion order, not creation time; sort explicitly.
    config_states.sort(key=lambda cs: cs["created_at"], reverse=True)

    for cs in config_states:
        timestamp = datetime.fromtimestamp(cs["created_at"]).strftime("%Y-%m-%d %H:%M:%S")
        name = cs.get("name", "Config")
        all_config_states[f"{name}: {timestamp}"] = cs

    return all_config_states
|
|
|
|
|
|
def _read_repo_state(repo):
    """Best-effort read of remote URL, HEAD commit and branch from a GitPython repo.

    Each field is read independently, so one failure (e.g. no default remote, or a
    detached HEAD where ``active_branch`` raises) only blanks that field instead of
    discarding the others. Unreadable fields are ``None``.
    """
    def attempt(read):
        try:
            return read()
        except Exception:
            return None

    commit = attempt(lambda: repo.head.commit)
    return {
        "remote": attempt(lambda: next(repo.remote().urls, None)),
        # If `commit` is None the attribute access fails and attempt() yields None.
        "commit_hash": attempt(lambda: commit.hexsha),
        "commit_date": attempt(lambda: commit.committed_date),
        "branch": attempt(lambda: repo.active_branch.name),
    }


def get_webui_config():
    """Describe the git state of the webui checkout at ``script_path``.

    Returns:
        dict with keys ``remote`` (first URL of ``repo.remote()``), ``commit_hash``,
        ``commit_date`` (HEAD's ``committed_date``) and ``branch``. Any value that
        cannot be read is ``None``. All four are ``None`` when ``script_path`` has
        no ``.git``, the repo cannot be opened, or the repo is bare.
    """
    webui_repo = None

    try:
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)

    if webui_repo is None or webui_repo.bare:
        return {"remote": None, "commit_hash": None, "commit_date": None, "branch": None}

    return _read_repo_state(webui_repo)
|
|
|
|
|
|
def get_extension_config():
    """Refresh every extension's repository info and snapshot its state.

    Returns:
        dict mapping each extension's name to a dict with the keys ``name``,
        ``path``, ``enabled``, ``is_builtin``, ``remote``, ``commit_hash``,
        ``commit_date``, ``branch`` and ``have_info_from_repo``. If two extensions
        share a name, the later one in ``extensions.extensions`` wins.
    """
    snapshot = {}

    for extension in extensions.extensions:
        # Re-read git info first so the captured attributes are current.
        extension.read_info_from_repo()

        snapshot[extension.name] = dict(
            name=extension.name,
            path=extension.path,
            enabled=extension.enabled,
            is_builtin=extension.is_builtin,
            remote=extension.remote,
            commit_hash=extension.commit_hash,
            commit_date=extension.commit_date,
            branch=extension.branch,
            have_info_from_repo=extension.have_info_from_repo,
        )

    return snapshot
|
|
|
|
|
|
def save_config(filename, config_state):
    """Store ``config_state`` in ``config_states_cache`` under the key ``filename``.

    Args:
        filename: cache key identifying the saved state.
        config_state: the state to persist (as produced by ``get_config()``).
    """
    cache = config_states_cache
    cache[filename] = config_state
|
|
|
|
|
|
def get_config():
    """Capture the current webui and extension state as a single dict.

    Returns:
        dict with ``created_at`` (POSIX timestamp of this call), ``webui`` (from
        ``get_webui_config()``) and ``extensions`` (from ``get_extension_config()``).
    """
    snapshot_time = datetime.now().timestamp()

    # Keyword arguments are evaluated left to right, so the webui state is read
    # before the extension state.
    return dict(
        created_at=snapshot_time,
        webui=get_webui_config(),
        extensions=get_extension_config(),
    )
|
|
|
|
|
|
def restore_webui_config(config):
    """Hard-reset the webui checkout to the commit recorded in ``config["webui"]``.

    Prints an error and returns without touching the repository when ``config`` has
    no ``webui`` section, no (or an empty) ``commit_hash``, or ``script_path`` is not
    a readable git checkout. Otherwise runs ``git fetch --all`` followed by
    ``git reset --hard <commit_hash>``; failures are reported via ``errors.report``.

    Args:
        config: a saved config state (as produced by ``get_config()``).
    """
    print("* Restoring webui state...")

    if "webui" not in config:
        print("Error: No webui data saved to config")
        return

    webui_config = config["webui"]
    webui_commit_hash = webui_config.get("commit_hash")

    # A missing *or* None/empty hash must abort: GitPython drops None positional
    # arguments, so the reset would degrade to a bare `git reset --hard` that
    # discards local changes while claiming success.
    if not webui_commit_hash:
        print("Error: No commit saved to webui config")
        return

    webui_repo = None
    try:
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
        return

    if webui_repo is None:
        print(f"Error: {script_path} is not a git repository, cannot restore webui")
        return

    try:
        webui_repo.git.fetch(all=True)
        webui_repo.git.reset(webui_commit_hash, hard=True)
        print(f"* Restored webui to commit {webui_commit_hash}.")
    except Exception:
        errors.report(f"Error restoring webui to commit {webui_commit_hash}", exc_info=True)
|
|
|
|
|
|
def _short_hash(commit_hash):
    """Return the first 8 characters of ``commit_hash``, or ``"unknown"`` if it is empty/None."""
    return commit_hash[:8] if commit_hash else "unknown"


def restore_extension_config(config):
    """Restore non-builtin extensions to the commits and enabled flags in ``config``.

    For each non-builtin extension in ``extensions.extensions``:
      * if it has no saved entry, it is marked disabled;
      * otherwise it is hard-reset to the saved ``commit_hash`` (when one is saved)
        and its ``disabled`` flag is set from the saved ``enabled`` value.
    Builtin extensions are skipped entirely, so they never appear in the new
    disabled list. ``shared.opts.disabled_extensions`` is then replaced with the
    collected list, the options are saved to ``shared.config_filename``, and a
    per-extension summary is printed.

    Args:
        config: a saved config state (as produced by ``get_config()``).
    """
    print("* Restoring extension state...")

    if "extensions" not in config:
        print("Error: No extension data saved to config")
        return

    ext_config = config["extensions"]

    results = []
    disabled = []

    for ext in tqdm.tqdm(extensions.extensions):
        if ext.is_builtin:
            continue

        ext.read_info_from_repo()
        current_commit = ext.commit_hash
        # commit_hash may be None (presumably when git info could not be read);
        # slicing it directly would raise and abort the restore halfway, before
        # the disabled list is saved.
        prev_short = _short_hash(current_commit)

        if ext.name not in ext_config:
            ext.disabled = True
            disabled.append(ext.name)
            results.append((ext, prev_short, False, "Saved extension state not found in config, marking as disabled"))
            continue

        entry = ext_config[ext.name]
        target_commit = entry.get("commit_hash")

        if target_commit:
            try:
                ext.fetch_and_reset_hard(target_commit)
                ext.read_info_from_repo()
                if current_commit != target_commit:
                    results.append((ext, prev_short, True, _short_hash(target_commit)))
            except Exception as ex:
                results.append((ext, prev_short, False, ex))
        else:
            results.append((ext, prev_short, False, "No commit hash found in config"))

        if entry.get("enabled", False):
            ext.disabled = False
        else:
            ext.disabled = True
            disabled.append(ext.name)

    shared.opts.disabled_extensions = disabled
    shared.opts.save(shared.config_filename)

    print("* Finished restoring extensions. Results:")
    for ext, prev_commit, success, result in results:
        if success:
            print(f"  + {ext.name}: {prev_commit} -> {result}")
        else:
            print(f"  ! {ext.name}: FAILURE ({result})")
|
|
|
|
|
|
def convert_old_cached_data():
    """Import legacy ``*.json`` config-state files into ``config_states_cache``.

    Each ``<name>.json`` in ``config_states_dir`` is loaded and stored under the key
    ``<name>``. The file is then moved to ``<script_path>/tmp`` so it is not imported
    again. A file that cannot be read or parsed is reported and left in place, and a
    failed move is reported too. Either way, one bad file does not abort module import.
    """
    import json

    tmp_dir = os.path.join(script_path, "tmp")

    for file in os.listdir(config_states_dir):
        if not file.endswith(".json"):
            continue

        path = os.path.join(config_states_dir, file)
        try:
            # JSON text is UTF-8; don't depend on the platform's default encoding.
            with open(path, encoding="utf-8") as f:
                config_state = json.load(f)
        except Exception:
            errors.report(f"Error reading old config state {path}", exc_info=True)
            continue

        config_states_cache[os.path.splitext(file)[0]] = config_state

        try:
            # The tmp directory is not guaranteed to exist; os.replace would raise otherwise.
            os.makedirs(tmp_dir, exist_ok=True)
            os.replace(path, os.path.join(tmp_dir, file))
        except OSError:
            errors.report(f"Error moving converted config state {path} to {tmp_dir}", exc_info=True)


convert_old_cached_data()
|