2023-01-21 21:15:53 +08:00
import torch
2023-01-28 22:18:47 +08:00
import gradio as gr
2023-05-09 16:25:46 +08:00
from fastapi import FastAPI
2023-01-21 21:15:53 +08:00
import lora
import extra_networks_lora
import ui_extra_networks_lora
2023-01-25 16:29:46 +08:00
from modules import script_callbacks , ui_extra_networks , extra_networks , shared
2023-01-21 21:15:53 +08:00
def unload():
    """Undo the Lora monkey-patches on the torch.nn layer classes.

    For each of ``Linear``, ``Conv2d`` and ``MultiheadAttention``, put back the
    ``forward`` and ``_load_from_state_dict`` implementations saved on
    ``torch.nn`` as ``<Class>_forward_before_lora`` and
    ``<Class>_load_state_dict_before_lora``. Raises AttributeError if a
    backup is missing. Registered via ``script_callbacks.on_script_unloaded``.
    """
    nn = torch.nn
    for layer_cls in (nn.Linear, nn.Conv2d, nn.MultiheadAttention):
        prefix = layer_cls.__name__
        layer_cls.forward = getattr(nn, f"{prefix}_forward_before_lora")
        layer_cls._load_from_state_dict = getattr(nn, f"{prefix}_load_state_dict_before_lora")
2023-01-21 21:15:53 +08:00
def before_ui():
    """Register Lora with the extra-networks machinery before the UI is built.

    Adds the Lora page (``ExtraNetworksPageLora``) to the extra-networks UI and
    the Lora network handler (``ExtraNetworkLora``) to ``extra_networks``.
    Registered via ``script_callbacks.on_before_ui``.
    """
    lora_page = ui_extra_networks_lora.ExtraNetworksPageLora()
    ui_extra_networks.register_page(lora_page)

    lora_network = extra_networks_lora.ExtraNetworkLora()
    extra_networks.register_extra_network(lora_network)
def _install_lora_hook(layer_cls, attr_name, backup_name, lora_impl):
    """Replace ``layer_cls.<attr_name>`` with ``lora_impl``.

    The implementation currently found on the class is first stored as
    ``torch.nn.<backup_name>``, but only when that backup does not exist yet.
    An existing backup is therefore never overwritten, presumably so the
    genuine torch original survives if this module runs more than once.
    ``unload`` restores the classes from these backups.
    """
    if not hasattr(torch.nn, backup_name):
        setattr(torch.nn, backup_name, getattr(layer_cls, attr_name))
    setattr(layer_cls, attr_name, lora_impl)


# Route forward passes and state-dict loading of these torch.nn layer classes
# through the implementations provided by the ``lora`` module.
_install_lora_hook(torch.nn.Linear, "forward", "Linear_forward_before_lora", lora.lora_Linear_forward)
_install_lora_hook(torch.nn.Linear, "_load_from_state_dict", "Linear_load_state_dict_before_lora", lora.lora_Linear_load_state_dict)
_install_lora_hook(torch.nn.Conv2d, "forward", "Conv2d_forward_before_lora", lora.lora_Conv2d_forward)
_install_lora_hook(torch.nn.Conv2d, "_load_from_state_dict", "Conv2d_load_state_dict_before_lora", lora.lora_Conv2d_load_state_dict)
_install_lora_hook(torch.nn.MultiheadAttention, "forward", "MultiheadAttention_forward_before_lora", lora.lora_MultiheadAttention_forward)
_install_lora_hook(torch.nn.MultiheadAttention, "_load_from_state_dict", "MultiheadAttention_load_state_dict_before_lora", lora.lora_MultiheadAttention_load_state_dict)
2023-01-21 21:15:53 +08:00
# Hook the Lora extension into the webui lifecycle.
# After a model is loaded, let the lora module inspect it; judging by the
# name, this tags the model's modules with Lora names -- TODO confirm in lora.
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
# When this script is unloaded, restore the original torch.nn methods.
script_callbacks.on_script_unloaded(unload)
# Register the Lora extra-networks page and handler before the UI is built.
script_callbacks.on_before_ui(before_ui)
2023-05-08 12:28:30 +08:00
# Let lora.infotext_pasted process generation-parameter text pasted into the
# UI; what it changes is defined in the lora module, not here.
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
2023-01-25 16:29:46 +08:00
# "Extra Networks" settings section: option ``sd_lora`` is a gr.Dropdown
# (default "None", label "Add Lora to prompt"). Its choices are "None" followed
# by the keys of ``lora.available_loras``, computed lazily by the lambda so they
# reflect the dict at call time. ``lora.list_available_loras`` is its refresh hook.
shared.options_templates.update(
    shared.options_section(
        ('extra_networks', "Extra Networks"),
        dict(
            sd_lora=shared.OptionInfo(
                "None",
                "Add Lora to prompt",
                gr.Dropdown,
                lambda: {"choices": ["None", *lora.available_loras]},
                refresh=lora.list_available_loras,
            ),
        ),
    )
)

# "Compatibility" settings section: boolean ``lora_functional`` (default False)
# that opts into the older Lora application method described by its label.
shared.options_templates.update(
    shared.options_section(
        ('compatibility', "Compatibility"),
        dict(
            lora_functional=shared.OptionInfo(
                False,
                "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension",
            ),
        ),
    )
)
2023-05-09 16:25:46 +08:00
def create_lora_json(obj: lora.LoraOnDisk):
    """Convert a Lora entry into the plain dict served by ``/sdapi/v1/loras``.

    Keys, in order: ``name``, ``alias``, ``path`` (taken from
    ``obj.filename``) and ``metadata``. Values are passed through unchanged.
    """
    return dict(
        name=obj.name,
        alias=obj.alias,
        path=obj.filename,
        metadata=obj.metadata,
    )
def api_loras(_: gr.Blocks, app: FastAPI):
    """Attach the Lora REST endpoint to the webui FastAPI ``app``.

    The first argument (the gradio Blocks instance) is ignored. Adds
    ``GET /sdapi/v1/loras``, which returns one ``create_lora_json`` dict for
    every value in ``lora.available_loras``.
    """

    # The handler keeps the name ``get_loras``: FastAPI derives the route's
    # OpenAPI operation id from the function name.
    @app.get("/sdapi/v1/loras")
    async def get_loras():
        entries = lora.available_loras.values()
        return list(map(create_lora_json, entries))


script_callbacks.on_app_started(api_loras)