mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-19 21:00:14 +08:00
rework loras api
This commit is contained in:
parent
7e02a00c81
commit
eb95809501
@ -3,7 +3,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
import scripts.api as api
|
|
||||||
|
|
||||||
from modules import shared, devices, sd_models, errors, scripts
|
from modules import shared, devices, sd_models, errors, scripts
|
||||||
|
|
||||||
@ -449,8 +448,3 @@ available_lora_aliases = {}
|
|||||||
loaded_loras = []
|
loaded_loras = []
|
||||||
|
|
||||||
list_available_loras()
|
list_available_loras()
|
||||||
try:
|
|
||||||
import modules.script_callbacks as script_callbacks
|
|
||||||
script_callbacks.on_app_started(api.api)
|
|
||||||
except:
|
|
||||||
pass
|
|
@ -1,31 +0,0 @@
|
|||||||
from fastapi import FastAPI
|
|
||||||
import gradio as gr
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import lora
|
|
||||||
|
|
||||||
def get_lora_prompts(path):
|
|
||||||
directory, filename = os.path.split(path)
|
|
||||||
name_without_ext = os.path.splitext(filename)[0]
|
|
||||||
new_filename = name_without_ext + '.civitai.info'
|
|
||||||
try:
|
|
||||||
new_path = os.path.join(directory, new_filename)
|
|
||||||
if os.path.exists(new_path):
|
|
||||||
with open(new_path, 'r') as f:
|
|
||||||
data = json.load(f)
|
|
||||||
trained_words = data.get('trainedWords', [])
|
|
||||||
if len(trained_words) > 0:
|
|
||||||
result = ','.join(trained_words)
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
return ''
|
|
||||||
else:
|
|
||||||
return ''
|
|
||||||
except Exception as e:
|
|
||||||
return ''
|
|
||||||
|
|
||||||
def api(_: gr.Blocks, app: FastAPI):
|
|
||||||
@app.get("/sdapi/v1/loras")
|
|
||||||
async def get_loras():
|
|
||||||
return [{"name": name, "path": lora.available_loras[name].filename, "prompt": get_lora_prompts(lora.available_loras[name].filename)} for name in lora.available_loras]
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
import lora
|
import lora
|
||||||
import extra_networks_lora
|
import extra_networks_lora
|
||||||
import ui_extra_networks_lora
|
import ui_extra_networks_lora
|
||||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||||
|
|
||||||
|
|
||||||
def unload():
|
def unload():
|
||||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
||||||
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
||||||
@ -60,3 +60,22 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
|
|||||||
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
|
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
|
||||||
"lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
|
"lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
def create_lora_json(obj: lora.LoraOnDisk):
|
||||||
|
return {
|
||||||
|
"name": obj.name,
|
||||||
|
"alias": obj.alias,
|
||||||
|
"path": obj.filename,
|
||||||
|
"metadata": obj.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def api_loras(_: gr.Blocks, app: FastAPI):
|
||||||
|
@app.get("/sdapi/v1/loras")
|
||||||
|
async def get_loras():
|
||||||
|
return [create_lora_json(obj) for obj in lora.available_loras.values()]
|
||||||
|
|
||||||
|
|
||||||
|
script_callbacks.on_app_started(api_loras)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user