2022-11-02 00:13:59 +08:00
import base64
import io
2023-07-15 12:44:37 +08:00
import os
2022-10-30 03:45:29 +08:00
import time
2023-01-03 22:45:16 +08:00
import datetime
2022-10-17 14:58:42 +08:00
import uvicorn
2023-08-20 21:41:27 +08:00
import ipaddress
import requests
2023-03-26 02:16:35 +08:00
import gradio as gr
2022-11-03 11:51:22 +08:00
from threading import Lock
2022-11-23 17:43:58 +08:00
from io import BytesIO
2023-03-16 03:11:04 +08:00
from fastapi import APIRouter , Depends , FastAPI , Request , Response
2022-11-15 16:12:34 +08:00
from fastapi . security import HTTPBasic , HTTPBasicCredentials
2023-03-16 03:11:04 +08:00
from fastapi . exceptions import HTTPException
from fastapi . responses import JSONResponse
from fastapi . encoders import jsonable_encoder
2022-11-15 16:12:34 +08:00
from secrets import compare_digest
2022-10-17 14:58:42 +08:00
import modules . shared as shared
2024-04-20 10:29:22 +08:00
from modules import sd_samplers , deepbooru , sd_hijack , images , scripts , ui , postprocessing , errors , restart , shared_items , script_callbacks , infotext_utils , sd_models , sd_schedulers
2023-05-10 13:25:25 +08:00
from modules . api import models
from modules . shared import opts
2022-10-22 07:27:40 +08:00
from modules . processing import StableDiffusionProcessingTxt2Img , StableDiffusionProcessingImg2Img , process_images
2022-12-25 07:02:22 +08:00
from modules . textual_inversion . textual_inversion import create_embedding , train_embedding
from modules . hypernetworks . hypernetwork import create_hypernetwork , train_hypernetwork
2024-03-06 18:04:58 +08:00
from PIL import PngImagePlugin
2023-01-27 16:28:12 +08:00
from modules . sd_models_config import find_checkpoint_config_near_filename
2022-11-03 11:51:22 +08:00
from modules . realesrgan_model import get_realesrgan_models
2022-12-25 07:02:22 +08:00
from modules import devices
2023-08-25 15:58:19 +08:00
from typing import Any
2023-01-23 23:10:59 +08:00
import piexif
import piexif . helper
2023-07-04 01:17:47 +08:00
from contextlib import closing
2023-12-15 16:57:17 +08:00
from modules . progress import create_task_id , add_task_to_queue , start_task , finish_task , current_task
2022-10-17 14:58:42 +08:00
2023-01-06 05:21:48 +08:00
def script_name_to_index(name, scripts):
    """Find the position of the script whose title matches `name`, ignoring case.

    Raises HTTPException(422) when no title matches (or the lookup fails in any way).
    """
    try:
        lowered_titles = [candidate.title().lower() for candidate in scripts]
        wanted = name.lower()
        return lowered_titles.index(wanted)
    except Exception as e:
        raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
2022-10-30 14:10:22 +08:00
2022-11-19 17:01:51 +08:00
def validate_sampler_name(name):
    """Return `name` unchanged if it names a registered sampler, else raise HTTPException(400)."""
    if sd_samplers.all_samplers_map.get(name) is None:
        raise HTTPException(status_code=400, detail="Sampler not found")
    return name
2022-10-22 07:27:40 +08:00
2023-05-10 16:19:16 +08:00
2022-10-24 03:01:16 +08:00
def setUpscalers(req: Any):
    """Build the keyword-argument dict passed to `postprocessing.run_extras` from an extras request.

    `req` is any object with a `__dict__` (the API passes pydantic request models;
    a plain dict would make `vars()` raise TypeError). The request's `upscaler_1` /
    `upscaler_2` attributes are renamed to `extras_upscaler_1` / `extras_upscaler_2`;
    a missing attribute becomes None.

    Returns a new dict. `req` itself is not modified, so the request model stays intact
    for anything that reads it afterwards.
    """
    # Copy: vars() returns the object's live __dict__, and popping from it would
    # strip fields off the request model itself.
    req_dict = dict(vars(req))
    req_dict['extras_upscaler_1'] = req_dict.pop('upscaler_1', None)
    req_dict['extras_upscaler_2'] = req_dict.pop('upscaler_2', None)
    return req_dict
2022-10-28 03:20:15 +08:00
2023-05-10 16:19:16 +08:00
2023-08-21 12:38:07 +08:00
def verify_url(url):
    """Returns True if the url refers to a global resource.

    Every address the URL's host resolves to (IPv4 and IPv6) must be a global
    address (`ipaddress` `is_global`). A URL without a host, a host that does not
    resolve, or any other failure is treated as non-global (returns False).
    """
    import socket
    from urllib.parse import urlparse
    try:
        # .hostname (unlike .netloc) drops any "user:pass@" prefix and ":port" suffix,
        # so URLs with an explicit port can be resolved at all.
        host = urlparse(url).hostname
        if not host:
            return False
        # getaddrinfo returns both IPv4 and IPv6 results; gethostbyname_ex is IPv4-only.
        addresses = {info[4][0] for info in socket.getaddrinfo(host, None)}
        if not addresses:
            return False
        for address in addresses:
            if not ipaddress.ip_address(address).is_global:
                return False
    except Exception:
        return False

    return True
def decode_base64_to_image(encoding):
    """Turn an API image argument into an image using `images.read`.

    `encoding` may be:
      * an http(s) URL: fetched only when `opts.api_enable_requests` is on. When
        `opts.api_forbid_local_requests` is on, the URL and every redirect target
        must pass `verify_url`;
      * a data URL (``data:image/...;base64,<payload>``);
      * a bare base64 string.

    Raises HTTPException(500) when requests are disabled, a local resource is
    requested while forbidden, or the content cannot be read as an image.
    """
    if encoding.startswith("http://") or encoding.startswith("https://"):
        if not opts.api_enable_requests:
            raise HTTPException(status_code=500, detail="Requests not allowed")
        if opts.api_forbid_local_requests and not verify_url(encoding):
            raise HTTPException(status_code=500, detail="Request to local resource not allowed")

        headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
        response = _fetch_url(encoding, headers, verify_redirects=opts.api_forbid_local_requests)
        try:
            image = images.read(BytesIO(response.content))
            return image
        except Exception as e:
            raise HTTPException(status_code=500, detail="Invalid image url") from e

    if encoding.startswith("data:image/"):
        # The payload is everything after the first comma; this also tolerates extra
        # ";param=value" segments in the data-URL header.
        encoding = encoding.partition(",")[2]
    try:
        image = images.read(BytesIO(base64.b64decode(encoding)))
        return image
    except Exception as e:
        raise HTTPException(status_code=500, detail="Invalid encoded image") from e


def _fetch_url(url, headers, verify_redirects, max_redirects=10):
    """GET `url` with a 30 s timeout and return the final `requests.Response`.

    When `verify_redirects` is true, redirects are followed manually and every
    redirect target is re-checked with `verify_url`, so a URL that passes the check
    cannot bounce the server to a local resource. Otherwise `requests` follows
    redirects itself.
    """
    if not verify_redirects:
        return requests.get(url, timeout=30, headers=headers)

    from urllib.parse import urljoin
    for _ in range(max_redirects + 1):
        response = requests.get(url, timeout=30, headers=headers, allow_redirects=False)
        if not response.is_redirect:
            return response
        # Location may be relative; resolve it against the URL that produced it.
        url = urljoin(url, response.headers["location"])
        if not verify_url(url):
            raise HTTPException(status_code=500, detail="Request to local resource not allowed")
    raise HTTPException(status_code=500, detail="Too many redirects")
2022-10-17 14:58:42 +08:00
2022-11-02 00:13:59 +08:00
def encode_pil_to_base64(image):
    """Save `image` in the format given by `opts.samples_format` and base64-encode it.

    A `str` argument is returned unchanged (presumably already encoded; confirm
    with callers).
      * png: string key/value pairs from `image.info` are written as PNG text chunks.
      * jpg/jpeg/webp: `image.info['parameters']` (if any) is stored in the EXIF
        UserComment, and modes the encoder cannot write are converted to RGB.

    Returns the base64-encoded bytes.
    Raises HTTPException(500) for any other `opts.samples_format`.
    """
    if isinstance(image, str):
        return image

    file_format = opts.samples_format.lower()

    with io.BytesIO() as output_bytes:
        if file_format == 'png':
            use_metadata = False
            metadata = PngImagePlugin.PngInfo()
            for key, value in image.info.items():
                if isinstance(key, str) and isinstance(value, str):
                    metadata.add_text(key, value)
                    use_metadata = True
            image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)

        elif file_format in ("jpg", "jpeg", "webp"):
            is_jpeg = file_format in ("jpg", "jpeg")
            # RGBA is flattened for both formats. JPEG additionally cannot encode
            # palette/alpha modes such as "P" or "LA", so those are converted as well.
            if image.mode == "RGBA" or (is_jpeg and image.mode not in ("RGB", "L", "CMYK")):
                image = image.convert("RGB")
            parameters = image.info.get('parameters', None)
            exif_bytes = piexif.dump({
                "Exif": {piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode")}
            })
            image.save(output_bytes, format="JPEG" if is_jpeg else "WEBP", exif=exif_bytes, quality=opts.jpeg_quality)

        else:
            raise HTTPException(status_code=500, detail="Invalid image format")

        bytes_data = output_bytes.getvalue()

    return base64.b64encode(bytes_data)
2022-11-02 00:13:59 +08:00
2023-05-10 16:19:16 +08:00
2023-01-03 23:58:52 +08:00
def api_middleware(app: FastAPI):
    """Install timing/logging middleware and unified exception handling on `app`.

    - Every response gets an ``X-Process-Time`` header (seconds, rounded to 4 places).
    - When `shared.cmd_opts.api_log` is set, requests whose path starts with
      ``/sdapi`` are printed to stdout.
    - Exceptions become JSON responses ``{error, detail, body, errors}``. Anything
      that is not an HTTPException is also reported, with a `rich` traceback when the
      ``WEBUI_RICH_EXCEPTIONS`` environment variable is set and `rich` imports.
    """
    rich_available = False
    try:
        if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
            import anyio  # importing just so it can be placed on silent list
            import starlette  # importing just so it can be placed on silent list
            from rich.console import Console
            console = Console()
            rich_available = True
    except Exception:
        pass

    @app.middleware("http")
    async def log_and_time(req: Request, call_next):
        ts = time.time()
        res: Response = await call_next(req)
        duration = str(round(time.time() - ts, 4))
        res.headers["X-Process-Time"] = duration
        endpoint = req.scope.get('path', 'err')
        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
            # The ASGI "client" entry may be present but None; fall back instead of indexing None.
            client = req.scope.get('client') or ('0:0.0.0', 0)
            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                code=res.status_code,
                ver=req.scope.get('http_version', '0.0'),
                cli=client[0],
                prot=req.scope.get('scheme', 'err'),
                method=req.scope.get('method', 'err'),
                endpoint=endpoint,
                duration=duration,
            ))
        return res

    def handle_exception(request: Request, e: Exception):
        err = {
            "error": type(e).__name__,
            "detail": vars(e).get('detail', ''),
            "body": vars(e).get('body', ''),
            "errors": str(e),
        }
        if not isinstance(e, HTTPException):  # do not print backtrace on known httpexceptions
            message = f"API error: {request.method}: {request.url} {err}"
            if rich_available:
                print(message)
                console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
            else:
                errors.report(message, exc_info=True)
        # Forward headers attached to the exception (e.g. the WWW-Authenticate header
        # raised by Api.auth) so clients still receive them on the error response.
        return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err), headers=vars(e).get('headers'))

    @app.middleware("http")
    async def exception_handling(request: Request, call_next):
        try:
            return await call_next(request)
        except Exception as e:
            return handle_exception(request, e)

    @app.exception_handler(Exception)
    async def fastapi_exception_handler(request: Request, e: Exception):
        return handle_exception(request, e)

    @app.exception_handler(HTTPException)
    async def http_exception_handler(request: Request, e: HTTPException):
        return handle_exception(request, e)
2022-11-02 00:13:59 +08:00
2022-10-17 14:58:42 +08:00
class Api :
2022-11-03 11:51:22 +08:00
def __init__ ( self , app : FastAPI , queue_lock : Lock ) :
2022-11-15 16:12:34 +08:00
if shared . cmd_opts . api_auth :
2023-05-10 16:55:09 +08:00
self . credentials = { }
2022-11-15 16:12:34 +08:00
for auth in shared . cmd_opts . api_auth . split ( " , " ) :
user , password = auth . split ( " : " )
2022-12-15 10:01:32 +08:00
self . credentials [ user ] = password
2022-11-15 16:12:34 +08:00
2022-10-17 14:58:42 +08:00
self . router = APIRouter ( )
2022-10-18 14:51:53 +08:00
self . app = app
self . queue_lock = queue_lock
2023-01-04 19:36:57 +08:00
api_middleware ( self . app )
2023-05-10 13:25:25 +08:00
self . add_api_route ( " /sdapi/v1/txt2img " , self . text2imgapi , methods = [ " POST " ] , response_model = models . TextToImageResponse )
self . add_api_route ( " /sdapi/v1/img2img " , self . img2imgapi , methods = [ " POST " ] , response_model = models . ImageToImageResponse )
self . add_api_route ( " /sdapi/v1/extra-single-image " , self . extras_single_image_api , methods = [ " POST " ] , response_model = models . ExtrasSingleImageResponse )
self . add_api_route ( " /sdapi/v1/extra-batch-images " , self . extras_batch_images_api , methods = [ " POST " ] , response_model = models . ExtrasBatchImagesResponse )
self . add_api_route ( " /sdapi/v1/png-info " , self . pnginfoapi , methods = [ " POST " ] , response_model = models . PNGInfoResponse )
self . add_api_route ( " /sdapi/v1/progress " , self . progressapi , methods = [ " GET " ] , response_model = models . ProgressResponse )
2022-11-15 16:12:34 +08:00
self . add_api_route ( " /sdapi/v1/interrogate " , self . interrogateapi , methods = [ " POST " ] )
self . add_api_route ( " /sdapi/v1/interrupt " , self . interruptapi , methods = [ " POST " ] )
2022-11-19 20:13:07 +08:00
self . add_api_route ( " /sdapi/v1/skip " , self . skip , methods = [ " POST " ] )
2023-05-10 13:25:25 +08:00
self . add_api_route ( " /sdapi/v1/options " , self . get_config , methods = [ " GET " ] , response_model = models . OptionsModel )
2022-11-15 16:12:34 +08:00
self . add_api_route ( " /sdapi/v1/options " , self . set_config , methods = [ " POST " ] )
2023-05-10 13:25:25 +08:00
self . add_api_route ( " /sdapi/v1/cmd-flags " , self . get_cmd_flags , methods = [ " GET " ] , response_model = models . FlagsModel )
2023-08-25 15:58:19 +08:00
self . add_api_route ( " /sdapi/v1/samplers " , self . get_samplers , methods = [ " GET " ] , response_model = list [ models . SamplerItem ] )
2024-04-20 10:29:22 +08:00
self . add_api_route ( " /sdapi/v1/schedulers " , self . get_schedulers , methods = [ " GET " ] , response_model = list [ models . SchedulerItem ] )
2023-08-25 15:58:19 +08:00
self . add_api_route ( " /sdapi/v1/upscalers " , self . get_upscalers , methods = [ " GET " ] , response_model = list [ models . UpscalerItem ] )
self . add_api_route ( " /sdapi/v1/latent-upscale-modes " , self . get_latent_upscale_modes , methods = [ " GET " ] , response_model = list [ models . LatentUpscalerModeItem ] )
self . add_api_route ( " /sdapi/v1/sd-models " , self . get_sd_models , methods = [ " GET " ] , response_model = list [ models . SDModelItem ] )
self . add_api_route ( " /sdapi/v1/sd-vae " , self . get_sd_vaes , methods = [ " GET " ] , response_model = list [ models . SDVaeItem ] )
self . add_api_route ( " /sdapi/v1/hypernetworks " , self . get_hypernetworks , methods = [ " GET " ] , response_model = list [ models . HypernetworkItem ] )
self . add_api_route ( " /sdapi/v1/face-restorers " , self . get_face_restorers , methods = [ " GET " ] , response_model = list [ models . FaceRestorerItem ] )
self . add_api_route ( " /sdapi/v1/realesrgan-models " , self . get_realesrgan_models , methods = [ " GET " ] , response_model = list [ models . RealesrganItem ] )
self . add_api_route ( " /sdapi/v1/prompt-styles " , self . get_prompt_styles , methods = [ " GET " ] , response_model = list [ models . PromptStyleItem ] )
2023-05-10 13:25:25 +08:00
self . add_api_route ( " /sdapi/v1/embeddings " , self . get_embeddings , methods = [ " GET " ] , response_model = models . EmbeddingsResponse )
2024-01-21 21:05:47 +08:00
self . add_api_route ( " /sdapi/v1/refresh-embeddings " , self . refresh_embeddings , methods = [ " POST " ] )
2022-12-12 03:16:44 +08:00
self . add_api_route ( " /sdapi/v1/refresh-checkpoints " , self . refresh_checkpoints , methods = [ " POST " ] )
2023-07-24 19:45:08 +08:00
self . add_api_route ( " /sdapi/v1/refresh-vae " , self . refresh_vae , methods = [ " POST " ] )
2023-05-10 13:25:25 +08:00
self . add_api_route ( " /sdapi/v1/create/embedding " , self . create_embedding , methods = [ " POST " ] , response_model = models . CreateResponse )
self . add_api_route ( " /sdapi/v1/create/hypernetwork " , self . create_hypernetwork , methods = [ " POST " ] , response_model = models . CreateResponse )
self . add_api_route ( " /sdapi/v1/train/embedding " , self . train_embedding , methods = [ " POST " ] , response_model = models . TrainResponse )
self . add_api_route ( " /sdapi/v1/train/hypernetwork " , self . train_hypernetwork , methods = [ " POST " ] , response_model = models . TrainResponse )
self . add_api_route ( " /sdapi/v1/memory " , self . get_memory , methods = [ " GET " ] , response_model = models . MemoryResponse )
2023-03-09 12:56:19 +08:00
self . add_api_route ( " /sdapi/v1/unload-checkpoint " , self . unloadapi , methods = [ " POST " ] )
self . add_api_route ( " /sdapi/v1/reload-checkpoint " , self . reloadapi , methods = [ " POST " ] )
2023-05-10 13:25:25 +08:00
self . add_api_route ( " /sdapi/v1/scripts " , self . get_scripts_list , methods = [ " GET " ] , response_model = models . ScriptsList )
2023-08-25 15:58:19 +08:00
self . add_api_route ( " /sdapi/v1/script-info " , self . get_script_info , methods = [ " GET " ] , response_model = list [ models . ScriptInfo ] )
self . add_api_route ( " /sdapi/v1/extensions " , self . get_extensions_list , methods = [ " GET " ] , response_model = list [ models . ExtensionItem ] )
2022-11-15 16:12:34 +08:00
2023-06-29 13:21:28 +08:00
if shared . cmd_opts . api_server_stop :
2023-06-14 17:51:47 +08:00
self . add_api_route ( " /sdapi/v1/server-kill " , self . kill_webui , methods = [ " POST " ] )
self . add_api_route ( " /sdapi/v1/server-restart " , self . restart_webui , methods = [ " POST " ] )
2023-06-14 18:53:08 +08:00
self . add_api_route ( " /sdapi/v1/server-stop " , self . stop_webui , methods = [ " POST " ] )
2022-11-15 16:12:34 +08:00
2023-03-26 02:16:35 +08:00
self . default_script_arg_txt2img = [ ]
self . default_script_arg_img2img = [ ]
2023-12-31 01:09:13 +08:00
txt2img_script_runner = scripts . scripts_txt2img
img2img_script_runner = scripts . scripts_img2img
if not txt2img_script_runner . scripts or not img2img_script_runner . scripts :
ui . create_ui ( )
if not txt2img_script_runner . scripts :
txt2img_script_runner . initialize_scripts ( False )
if not self . default_script_arg_txt2img :
self . default_script_arg_txt2img = self . init_default_script_args ( txt2img_script_runner )
if not img2img_script_runner . scripts :
img2img_script_runner . initialize_scripts ( True )
if not self . default_script_arg_img2img :
self . default_script_arg_img2img = self . init_default_script_args ( img2img_script_runner )
2023-12-30 21:32:22 +08:00
2023-12-26 14:46:29 +08:00
2022-11-15 16:12:34 +08:00
def add_api_route ( self , path : str , endpoint , * * kwargs ) :
if shared . cmd_opts . api_auth :
return self . app . add_api_route ( path , endpoint , dependencies = [ Depends ( self . auth ) ] , * * kwargs )
return self . app . add_api_route ( path , endpoint , * * kwargs )
2022-12-15 10:01:32 +08:00
def auth ( self , credentials : HTTPBasicCredentials = Depends ( HTTPBasic ( ) ) ) :
if credentials . username in self . credentials :
if compare_digest ( credentials . password , self . credentials [ credentials . username ] ) :
2022-11-15 16:12:34 +08:00
return True
raise HTTPException ( status_code = 401 , detail = " Incorrect username or password " , headers = { " WWW-Authenticate " : " Basic " } )
2022-10-17 14:58:42 +08:00
2023-02-27 08:17:58 +08:00
def get_selectable_script ( self , script_name , script_runner ) :
if script_name is None or script_name == " " :
2023-01-08 21:14:38 +08:00
return None , None
script_idx = script_name_to_index ( script_name , script_runner . selectable_scripts )
script = script_runner . selectable_scripts [ script_idx ]
return script , script_idx
2023-05-11 23:28:15 +08:00
2023-03-04 11:46:07 +08:00
def get_scripts_list ( self ) :
2023-05-18 03:43:24 +08:00
t2ilist = [ script . name for script in scripts . scripts_txt2img . scripts if script . name is not None ]
i2ilist = [ script . name for script in scripts . scripts_img2img . scripts if script . name is not None ]
2023-03-04 11:46:07 +08:00
2023-05-10 13:25:25 +08:00
return models . ScriptsList ( txt2img = t2ilist , img2img = i2ilist )
2023-01-07 22:21:31 +08:00
2023-05-18 03:43:24 +08:00
def get_script_info ( self ) :
res = [ ]
for script_list in [ scripts . scripts_txt2img . scripts , scripts . scripts_img2img . scripts ] :
res + = [ script . api_info for script in script_list if script . api_info is not None ]
return res
2023-02-27 08:17:58 +08:00
def get_script ( self , script_name , script_runner ) :
2023-02-28 12:27:33 +08:00
if script_name is None or script_name == " " :
return None , None
2023-05-11 23:28:15 +08:00
2023-02-28 12:27:33 +08:00
script_idx = script_name_to_index ( script_name , script_runner . scripts )
return script_runner . scripts [ script_idx ]
2023-02-27 08:17:58 +08:00
2023-03-26 02:16:35 +08:00
def init_default_script_args ( self , script_runner ) :
2023-02-27 08:17:58 +08:00
#find max idx from the scripts in runner and generate a none array to init script_args
last_arg_index = 1
for script in script_runner . scripts :
if last_arg_index < script . args_to :
last_arg_index = script . args_to
2023-02-28 12:27:33 +08:00
# None everywhere except position 0 to initialize script args
2023-02-27 08:17:58 +08:00
script_args = [ None ] * last_arg_index
2023-03-26 02:16:35 +08:00
script_args [ 0 ] = 0
# get default values
with gr . Blocks ( ) : # will throw errors calling ui function without this
for script in script_runner . scripts :
if script . ui ( script . is_img2img ) :
ui_default_values = [ ]
for elem in script . ui ( script . is_img2img ) :
ui_default_values . append ( elem . value )
script_args [ script . args_from : script . args_to ] = ui_default_values
return script_args
2023-12-30 18:33:18 +08:00
def init_script_args ( self , request , default_script_args , selectable_scripts , selectable_idx , script_runner , * , input_script_args = None ) :
2023-03-26 02:16:35 +08:00
script_args = default_script_args . copy ( )
2023-12-30 18:33:18 +08:00
if input_script_args is not None :
for index , value in input_script_args . items ( ) :
script_args [ index ] = value
2023-02-28 12:27:33 +08:00
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
if selectable_scripts :
script_args [ selectable_scripts . args_from : selectable_scripts . args_to ] = request . script_args
script_args [ 0 ] = selectable_idx + 1
2023-02-27 08:17:58 +08:00
# Now check for always on scripts
2023-06-02 19:58:10 +08:00
if request . alwayson_scripts :
2023-03-12 01:21:33 +08:00
for alwayson_script_name in request . alwayson_scripts . keys ( ) :
2023-02-27 08:17:58 +08:00
alwayson_script = self . get_script ( alwayson_script_name , script_runner )
2023-05-10 12:52:45 +08:00
if alwayson_script is None :
2023-02-27 08:17:58 +08:00
raise HTTPException ( status_code = 422 , detail = f " always on script { alwayson_script_name } not found " )
# Selectable script in always on script param check
2023-05-10 12:52:45 +08:00
if alwayson_script . alwayson is False :
raise HTTPException ( status_code = 422 , detail = " Cannot have a selectable script in the always on scripts params " )
2023-03-12 01:21:33 +08:00
# always on script with no arg should always run so you don't really need to add them to the requests
if " args " in request . alwayson_scripts [ alwayson_script_name ] :
2023-03-29 11:52:51 +08:00
# min between arg length in scriptrunner and arg length in the request
for idx in range ( 0 , min ( ( alwayson_script . args_to - alwayson_script . args_from ) , len ( request . alwayson_scripts [ alwayson_script_name ] [ " args " ] ) ) ) :
script_args [ alwayson_script . args_from + idx ] = request . alwayson_scripts [ alwayson_script_name ] [ " args " ] [ idx ]
2023-02-28 12:27:33 +08:00
return script_args
2023-12-30 18:33:18 +08:00
def apply_infotext ( self , request , tabname , * , script_runner = None , mentioned_script_args = None ) :
2024-03-04 14:37:23 +08:00
""" Processes `infotext` field from the `request`, and sets other fields of the `request` according to what ' s in infotext.
2023-12-30 18:48:25 +08:00
If request already has a field set , and that field is encountered in infotext too , the value from infotext is ignored .
Additionally , fills ` mentioned_script_args ` dict with index : value pairs for script arguments read from infotext .
"""
2023-12-17 15:22:03 +08:00
if not request . infotext :
return { }
2024-01-01 22:25:30 +08:00
possible_fields = infotext_utils . paste_fields [ tabname ] [ " fields " ]
2024-06-21 09:52:02 +08:00
set_fields = request . model_dump ( exclude_unset = True ) if hasattr ( request , " request " ) else request . dict ( exclude_unset = True ) # pydantic v1/v2 have different names for this
2024-01-01 22:25:30 +08:00
params = infotext_utils . parse_generation_parameters ( request . infotext )
2023-12-17 15:22:03 +08:00
2023-12-30 18:33:18 +08:00
def get_field_value ( field , params ) :
2023-12-17 15:22:03 +08:00
value = field . function ( params ) if field . function else params . get ( field . label )
if value is None :
2023-12-30 18:33:18 +08:00
return None
if field . api in request . __fields__ :
target_type = request . __fields__ [ field . api ] . type_
else :
target_type = type ( field . component . value )
if target_type == type ( None ) :
return None
2023-12-17 15:22:03 +08:00
2023-12-30 18:48:25 +08:00
if isinstance ( value , dict ) and value . get ( ' __type__ ' ) == ' generic_update ' : # this is a gradio.update rather than a value
value = value . get ( ' value ' )
if value is not None and not isinstance ( value , target_type ) :
2023-12-17 15:22:03 +08:00
value = target_type ( value )
2023-12-30 18:33:18 +08:00
return value
for field in possible_fields :
if not field . api :
continue
if field . api in set_fields :
continue
value = get_field_value ( field , params )
if value is not None :
setattr ( request , field . api , value )
2023-12-30 17:11:09 +08:00
if request . override_settings is None :
request . override_settings = { }
2024-03-04 14:37:23 +08:00
overridden_settings = infotext_utils . get_override_settings ( params )
for _ , setting_name , value in overridden_settings :
2023-12-30 17:11:09 +08:00
if setting_name not in request . override_settings :
request . override_settings [ setting_name ] = value
2023-12-17 15:22:03 +08:00
2023-12-30 18:33:18 +08:00
if script_runner is not None and mentioned_script_args is not None :
indexes = { v : i for i , v in enumerate ( script_runner . inputs ) }
script_fields = ( ( field , indexes [ field . component ] ) for field in possible_fields if field . component in indexes )
for field , index in script_fields :
2023-12-30 18:48:25 +08:00
value = get_field_value ( field , params )
if value is None :
continue
mentioned_script_args [ index ] = value
2023-12-30 18:33:18 +08:00
2023-12-17 15:22:03 +08:00
return params
2023-05-10 13:25:25 +08:00
def text2imgapi ( self , txt2imgreq : models . StableDiffusionTxt2ImgProcessingAPI ) :
2023-12-17 13:55:35 +08:00
task_id = txt2imgreq . force_task_id or create_task_id ( " txt2img " )
2023-02-28 12:27:33 +08:00
script_runner = scripts . scripts_txt2img
2023-12-17 15:22:03 +08:00
2023-12-30 18:33:18 +08:00
infotext_script_args = { }
self . apply_infotext ( txt2imgreq , " txt2img " , script_runner = script_runner , mentioned_script_args = infotext_script_args )
2023-12-17 15:22:03 +08:00
2023-02-28 12:27:33 +08:00
selectable_scripts , selectable_script_idx = self . get_selectable_script ( txt2imgreq . script_name , script_runner )
2024-04-30 15:53:41 +08:00
sampler , scheduler = sd_samplers . get_sampler_and_scheduler ( txt2imgreq . sampler_name or txt2imgreq . sampler_index , txt2imgreq . scheduler )
2023-02-28 12:27:33 +08:00
2023-03-11 18:22:59 +08:00
populate = txt2imgreq . copy ( update = { # Override __init__ params
2024-04-29 12:36:43 +08:00
" sampler_name " : validate_sampler_name ( sampler ) ,
2023-03-11 18:22:59 +08:00
" do_not_save_samples " : not txt2imgreq . save_images ,
" do_not_save_grid " : not txt2imgreq . save_images ,
} )
2023-02-28 12:27:33 +08:00
if populate . sampler_name :
populate . sampler_index = None # prevent a warning later on
2024-04-30 15:53:41 +08:00
if not populate . scheduler and scheduler != " Automatic " :
2024-04-29 12:36:43 +08:00
populate . scheduler = scheduler
2023-02-28 12:27:33 +08:00
args = vars ( populate )
args . pop ( ' script_name ' , None )
args . pop ( ' script_args ' , None ) # will refeed them to the pipeline directly after initializing them
2023-03-12 01:21:33 +08:00
args . pop ( ' alwayson_scripts ' , None )
2023-12-17 15:22:03 +08:00
args . pop ( ' infotext ' , None )
2023-02-28 12:27:33 +08:00
2023-12-30 18:33:18 +08:00
script_args = self . init_script_args ( txt2imgreq , self . default_script_arg_txt2img , selectable_scripts , selectable_script_idx , script_runner , input_script_args = infotext_script_args )
2023-01-07 22:21:31 +08:00
2023-03-11 18:22:59 +08:00
send_images = args . pop ( ' send_images ' , True )
args . pop ( ' save_images ' , None )
2023-03-03 22:00:52 +08:00
2023-12-15 16:57:17 +08:00
add_task_to_queue ( task_id )
2022-10-18 14:51:53 +08:00
with self . queue_lock :
2023-07-04 01:17:47 +08:00
with closing ( StableDiffusionProcessingTxt2Img ( sd_model = shared . sd_model , * * args ) ) as p :
2023-08-14 15:43:18 +08:00
p . is_api = True
2023-07-04 01:02:30 +08:00
p . scripts = script_runner
p . outpath_grids = opts . outdir_txt2img_grids
p . outpath_samples = opts . outdir_txt2img_samples
2023-07-22 12:03:21 +08:00
try :
shared . state . begin ( job = " scripts_txt2img " )
2023-12-15 16:57:17 +08:00
start_task ( task_id )
2023-07-22 12:03:21 +08:00
if selectable_scripts is not None :
p . script_args = script_args
processed = scripts . scripts_txt2img . run ( p , * p . script_args ) # Need to pass args as list here
else :
p . script_args = tuple ( script_args ) # Need to pass args as tuple here
processed = process_images ( p )
2023-12-15 16:57:17 +08:00
finish_task ( task_id )
2023-07-22 12:03:21 +08:00
finally :
shared . state . end ( )
2023-07-28 11:40:10 +08:00
shared . total_tqdm . clear ( )
2022-10-17 14:58:42 +08:00
2023-03-03 22:00:52 +08:00
b64images = list ( map ( encode_pil_to_base64 , processed . images ) ) if send_images else [ ]
2022-10-26 22:33:45 +08:00
2023-05-10 13:25:25 +08:00
return models . TextToImageResponse ( images = b64images , parameters = vars ( txt2imgreq ) , info = processed . js ( ) )
2022-10-17 14:58:42 +08:00
2023-05-10 13:25:25 +08:00
def img2imgapi ( self , img2imgreq : models . StableDiffusionImg2ImgProcessingAPI ) :
2023-12-17 13:55:35 +08:00
task_id = img2imgreq . force_task_id or create_task_id ( " img2img " )
2023-12-15 16:57:17 +08:00
2022-10-22 07:27:40 +08:00
init_images = img2imgreq . init_images
if init_images is None :
2022-10-26 22:33:45 +08:00
raise HTTPException ( status_code = 404 , detail = " Init image not found " )
2022-10-22 07:27:40 +08:00
2022-10-23 03:42:00 +08:00
mask = img2imgreq . mask
if mask :
2022-11-24 13:10:40 +08:00
mask = decode_base64_to_image ( mask )
2022-10-23 03:42:00 +08:00
2023-02-27 08:17:58 +08:00
script_runner = scripts . scripts_img2img
2023-12-30 18:34:46 +08:00
infotext_script_args = { }
self . apply_infotext ( img2imgreq , " img2img " , script_runner = script_runner , mentioned_script_args = infotext_script_args )
2023-02-28 12:27:33 +08:00
selectable_scripts , selectable_script_idx = self . get_selectable_script ( img2imgreq . script_name , script_runner )
2024-04-30 15:53:41 +08:00
sampler , scheduler = sd_samplers . get_sampler_and_scheduler ( img2imgreq . sampler_name or img2imgreq . sampler_index , img2imgreq . scheduler )
2023-02-27 08:17:58 +08:00
2023-03-12 03:34:56 +08:00
populate = img2imgreq . copy ( update = { # Override __init__ params
2024-04-29 12:36:43 +08:00
" sampler_name " : validate_sampler_name ( sampler ) ,
2023-03-11 18:22:59 +08:00
" do_not_save_samples " : not img2imgreq . save_images ,
" do_not_save_grid " : not img2imgreq . save_images ,
" mask " : mask ,
} )
2022-11-27 21:19:47 +08:00
if populate . sampler_name :
populate . sampler_index = None # prevent a warning later on
2022-12-03 14:15:24 +08:00
2024-04-30 15:53:41 +08:00
if not populate . scheduler and scheduler != " Automatic " :
2024-04-29 12:36:43 +08:00
populate . scheduler = scheduler
2022-12-03 14:15:24 +08:00
args = vars ( populate )
args . pop ( ' include_init_images ' , None ) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
2023-01-06 05:21:48 +08:00
args . pop ( ' script_name ' , None )
2023-02-28 12:27:33 +08:00
args . pop ( ' script_args ' , None ) # will refeed them to the pipeline directly after initializing them
2023-03-12 01:21:33 +08:00
args . pop ( ' alwayson_scripts ' , None )
2023-12-30 23:02:51 +08:00
args . pop ( ' infotext ' , None )
2023-02-27 08:17:58 +08:00
2023-12-30 18:34:46 +08:00
script_args = self . init_script_args ( img2imgreq , self . default_script_arg_img2img , selectable_scripts , selectable_script_idx , script_runner , input_script_args = infotext_script_args )
2022-10-30 14:10:22 +08:00
2023-03-11 18:22:59 +08:00
send_images = args . pop ( ' send_images ' , True )
args . pop ( ' save_images ' , None )
2023-03-03 22:00:52 +08:00
2023-12-15 16:57:17 +08:00
add_task_to_queue ( task_id )
2022-10-22 07:27:40 +08:00
with self . queue_lock :
2023-07-04 01:17:47 +08:00
with closing ( StableDiffusionProcessingImg2Img ( sd_model = shared . sd_model , * * args ) ) as p :
2023-07-04 01:02:30 +08:00
p . init_images = [ decode_base64_to_image ( x ) for x in init_images ]
2023-08-14 15:43:18 +08:00
p . is_api = True
2023-07-04 01:02:30 +08:00
p . scripts = script_runner
p . outpath_grids = opts . outdir_img2img_grids
p . outpath_samples = opts . outdir_img2img_samples
2023-07-22 12:03:21 +08:00
try :
shared . state . begin ( job = " scripts_img2img " )
2023-12-15 16:57:17 +08:00
start_task ( task_id )
2023-07-22 12:03:21 +08:00
if selectable_scripts is not None :
p . script_args = script_args
processed = scripts . scripts_img2img . run ( p , * p . script_args ) # Need to pass args as list here
else :
p . script_args = tuple ( script_args ) # Need to pass args as tuple here
processed = process_images ( p )
2023-12-15 16:57:17 +08:00
finish_task ( task_id )
2023-07-22 12:03:21 +08:00
finally :
shared . state . end ( )
2023-07-28 11:40:10 +08:00
shared . total_tqdm . clear ( )
2022-10-26 22:33:45 +08:00
2023-03-03 22:00:52 +08:00
b64images = list ( map ( encode_pil_to_base64 , processed . images ) ) if send_images else [ ]
2022-10-22 07:27:40 +08:00
2022-12-03 14:15:24 +08:00
if not img2imgreq . include_init_images :
2022-10-24 23:16:07 +08:00
img2imgreq . init_images = None
img2imgreq . mask = None
2023-05-10 13:25:25 +08:00
return models . ImageToImageResponse ( images = b64images , parameters = vars ( img2imgreq ) , info = processed . js ( ) )
2022-10-17 14:58:42 +08:00
2023-05-10 13:25:25 +08:00
def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
    """Run the postprocessing ("extras") pipeline on one base64-encoded image.

    The request is converted to keyword arguments by ``setUpscalers``; its
    ``image`` entry is decoded to a PIL image and everything is forwarded to
    ``postprocessing.run_extras`` with ``extras_mode=0`` and saving disabled.

    Returns:
        ExtrasSingleImageResponse with the first output image re-encoded to
        base64 and the second element of the run_extras result as html_info.
    """
    params = setUpscalers(req)
    params['image'] = decode_base64_to_image(params['image'])

    # Hold the same queue lock the other processing endpoints of this class take.
    with self.queue_lock:
        result = postprocessing.run_extras(
            extras_mode=0,
            image_folder="",
            input_dir="",
            output_dir="",
            save_output=False,
            **params,
        )

    # Index rather than unpack: only the first two elements of the result are used.
    first_image = result[0][0]
    return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(first_image), html_info=result[1])
2022-10-24 00:07:59 +08:00
2023-05-10 13:25:25 +08:00
def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
    """Run the postprocessing ("extras") pipeline on a batch of base64 images.

    The ``imageList`` entry is removed from the ``setUpscalers`` arguments
    (defaulting to an empty list); each item's ``data`` is decoded and the
    decoded images are passed as ``image_folder`` to
    ``postprocessing.run_extras`` with ``extras_mode=1`` and saving disabled.

    Returns:
        ExtrasBatchImagesResponse with every output image base64-encoded and
        the second element of the run_extras result as html_info.
    """
    params = setUpscalers(req)

    encoded_inputs = params.pop('imageList', [])
    decoded_inputs = []
    for item in encoded_inputs:
        decoded_inputs.append(decode_base64_to_image(item.data))

    # Hold the same queue lock the other processing endpoints of this class take.
    with self.queue_lock:
        result = postprocessing.run_extras(
            extras_mode=1,
            image_folder=decoded_inputs,
            image="",
            input_dir="",
            output_dir="",
            save_output=False,
            **params,
        )

    encoded_outputs = [encode_pil_to_base64(picture) for picture in result[0]]
    return models.ExtrasBatchImagesResponse(images=encoded_outputs, html_info=result[1])
2022-10-17 14:58:42 +08:00
2023-05-10 13:25:25 +08:00
def pnginfoapi(self, req: models.PNGInfoRequest):
    """Extract generation info embedded in a base64-encoded image.

    Returns an empty ``info`` if decoding yields ``None``. Otherwise the
    infotext read by ``images.read_info_from_image`` (``""`` when absent) is
    parsed with ``infotext_utils.parse_generation_parameters`` and returned
    together with the other metadata items and the parsed parameters.
    """
    decoded = decode_base64_to_image(req.image.strip())
    if decoded is None:
        return models.PNGInfoResponse(info="")

    geninfo, items = images.read_info_from_image(decoded)
    geninfo = "" if geninfo is None else geninfo

    parameters = infotext_utils.parse_generation_parameters(geninfo)
    # Run the registered infotext_pasted callbacks on the parsed dict
    # (presumably they may modify it in place).
    script_callbacks.infotext_pasted_callback(geninfo, parameters)

    return models.PNGInfoResponse(info=geninfo, items=items, parameters=parameters)
2022-10-17 14:58:42 +08:00
2023-05-10 13:25:25 +08:00
def progressapi(self, req: models.ProgressRequest = Depends()):
    """Report the progress of the current job to API clients.

    Mirrors ``check_progress_call`` of ui.py. When ``shared.state.job_count``
    is 0 a zero-progress response is returned immediately. Otherwise
    progress is estimated from ``job_no / job_count`` plus the fraction of
    sampling steps done in the current job, and ``eta_relative`` is
    extrapolated from the time elapsed since ``shared.state.time_start``.
    The current preview image is included as base64 unless
    ``req.skip_current_image`` is set.
    """
    if shared.state.job_count == 0:
        return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)

    # Start slightly above zero so the ETA division below never divides by zero.
    progress = 0.01
    # job_count can still be negative here (0 was handled above). Only use it as a
    # divisor when positive: adding `1 / job_count * step / steps` for a negative
    # count made progress negative, or exactly 0.0 (e.g. job_count=-1 with 1 of
    # 100 steps done), which crashed the ETA computation with ZeroDivisionError.
    if shared.state.job_count > 0:
        progress += shared.state.job_no / shared.state.job_count
        if shared.state.sampling_steps > 0:
            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

    time_since_start = time.time() - shared.state.time_start
    eta = time_since_start / progress
    eta_relative = eta - time_since_start

    progress = min(progress, 1)

    # Presumably refreshes shared.state.current_image before it is read below.
    shared.state.set_current_image()

    current_image = None
    if shared.state.current_image and not req.skip_current_image:
        current_image = encode_pil_to_base64(shared.state.current_image)

    return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
2022-10-26 22:33:45 +08:00
2023-05-10 13:25:25 +08:00
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
    """Caption a base64-encoded image with the requested interrogator.

    ``interrogatereq.model`` selects ``"clip"`` (``shared.interrogator``) or
    ``"deepdanbooru"`` (``deepbooru.model.tag``); the image is converted to
    RGB first.

    Raises:
        HTTPException(404): if no image was supplied, or the model name is
            not one of the two supported values.
    """
    if interrogatereq.image is None:
        raise HTTPException(status_code=404, detail="Image not found")

    rgb_image = decode_base64_to_image(interrogatereq.image).convert('RGB')

    # Supported interrogators; the backing objects are looked up only when called.
    interrogators = {
        "clip": lambda picture: shared.interrogator.interrogate(picture),
        "deepdanbooru": lambda picture: deepbooru.model.tag(picture),
    }

    with self.queue_lock:
        run = interrogators.get(interrogatereq.model)
        if run is None:
            raise HTTPException(status_code=404, detail="Model not found")
        caption = run(rgb_image)

    return models.InterrogateResponse(caption=caption)
2022-10-17 14:58:42 +08:00
2022-10-30 18:08:40 +08:00
def interruptapi(self):
    """Call ``shared.state.interrupt()`` and reply with an empty JSON object."""
    shared.state.interrupt()
    empty_reply = {}
    return empty_reply
2023-03-09 12:56:19 +08:00
def unloadapi(self):
    """Call ``sd_models.unload_model_weights()`` and reply with an empty JSON object."""
    sd_models.unload_model_weights()
    empty_reply = {}
    return empty_reply
def reloadapi(self):
    """Pass ``shared.sd_model`` to ``sd_models.send_model_to_device`` and reply with ``{}``."""
    current_model = shared.sd_model
    sd_models.send_model_to_device(current_model)
    empty_reply = {}
    return empty_reply
2022-11-06 06:05:15 +08:00
def skip(self):
    """Call ``shared.state.skip()``; the endpoint returns no body (``None``)."""
    shared.state.skip()
    return None
2022-11-03 11:51:22 +08:00
def get_config(self):
    """Return a snapshot (shallow copy) of the stored option values.

    Only keys present in ``shared.opts.data`` are included. The previous
    implementation looked up ``shared.opts.data_labels`` defaults per key, but
    since every key was taken from ``shared.opts.data`` itself that default
    could never be used; the result is identical to copying the dict.
    """
    return dict(shared.opts.data)
2022-11-05 01:43:02 +08:00
2023-08-25 15:58:19 +08:00
def set_config(self, req: dict[str, Any]):
    """Apply the given option values and persist them to the config file.

    Each key/value pair is set via ``shared.opts.set(..., is_api=True)`` and
    then ``shared.opts.save(shared.config_filename)`` is called.

    Raises:
        RuntimeError: if ``sd_model_checkpoint`` is present and not one of
            ``sd_models.checkpoint_aliases``; no option is changed in that case.
    """
    requested_checkpoint = req.get("sd_model_checkpoint")
    unknown_checkpoint = requested_checkpoint is not None and requested_checkpoint not in sd_models.checkpoint_aliases
    if unknown_checkpoint:
        raise RuntimeError(f"model {requested_checkpoint!r} not found")

    for option_name, option_value in req.items():
        shared.opts.set(option_name, option_value, is_api=True)

    shared.opts.save(shared.config_filename)
    return None
def get_cmd_flags(self):
    """Return the attribute dict of ``shared.cmd_opts`` (the parsed command-line flags)."""
    flags = vars(shared.cmd_opts)
    return flags
def get_samplers(self):
    """List every entry of ``sd_samplers.all_samplers`` as name/aliases/options dicts.

    Fields are read positionally: index 0 is the name, 2 the aliases and
    3 the options.
    """
    listing = []
    for entry in sd_samplers.all_samplers:
        listing.append({"name": entry[0], "aliases": entry[2], "options": entry[3]})
    return listing
2022-11-03 11:51:22 +08:00
2024-04-20 10:29:22 +08:00
def get_schedulers(self):
    """Describe each scheduler in ``sd_schedulers.schedulers``.

    Each entry reports name, label, aliases, default_rho and need_inner_model.
    """
    listing = []
    for entry in sd_schedulers.schedulers:
        listing.append({
            "name": entry.name,
            "label": entry.label,
            "aliases": entry.aliases,
            "default_rho": entry.default_rho,
            "need_inner_model": entry.need_inner_model,
        })
    return listing
2022-11-03 11:51:22 +08:00
def get_upscalers(self):
    """Describe each entry of ``shared.sd_upscalers``.

    ``model_url`` is always reported as ``None``.
    """
    listing = []
    for entry in shared.sd_upscalers:
        listing.append({
            "name": entry.name,
            "model_name": entry.scaler.model_name,
            "model_path": entry.data_path,
            "model_url": None,
            "scale": entry.scale,
        })
    return listing
2022-11-05 01:43:02 +08:00
2023-06-04 21:59:23 +08:00
def get_latent_upscale_modes(self):
    """List the keys of ``shared.latent_upscale_modes`` (empty if it is falsy) as ``{"name": ...}``."""
    modes = shared.latent_upscale_modes or {}
    return [{"name": mode} for mode in modes]
2022-11-03 11:51:22 +08:00
def get_sd_models(self):
    """Describe every checkpoint in ``sd_models.checkpoints_list``.

    Each entry carries title, model name, short hash, sha256, filename and
    the config file found next to the checkpoint (via
    ``find_checkpoint_config_near_filename``).
    """
    import modules.sd_models as sd_models

    def describe(checkpoint):
        return {
            "title": checkpoint.title,
            "model_name": checkpoint.model_name,
            "hash": checkpoint.shorthash,
            "sha256": checkpoint.sha256,
            "filename": checkpoint.filename,
            "config": find_checkpoint_config_near_filename(checkpoint),
        }

    return [describe(checkpoint) for checkpoint in sd_models.checkpoints_list.values()]
2022-11-03 11:51:22 +08:00
2023-05-30 05:25:43 +08:00
def get_sd_vaes(self):
    """List the entries of ``sd_vae.vae_dict`` as ``{"model_name", "filename"}`` dicts."""
    import modules.sd_vae as sd_vae

    listing = []
    for vae_name, vae_path in sd_vae.vae_dict.items():
        listing.append({"model_name": vae_name, "filename": vae_path})
    return listing
2022-11-03 11:51:22 +08:00
def get_hypernetworks(self):
    """List ``shared.hypernetworks`` as ``{"name", "path"}`` dicts."""
    listing = []
    for hn_name, hn_path in shared.hypernetworks.items():
        listing.append({"name": hn_name, "path": hn_path})
    return listing
def get_face_restorers(self):
    """List ``shared.face_restorers`` with their name and ``cmd_dir`` (``None`` if absent)."""
    listing = []
    for restorer in shared.face_restorers:
        listing.append({"name": restorer.name(), "cmd_dir": getattr(restorer, "cmd_dir", None)})
    return listing
def get_realesrgan_models(self):
    """List the models returned by the module-level ``get_realesrgan_models(None)``.

    Note: inside this body the name resolves to the imported helper, not to
    this method.
    """
    listing = []
    for model in get_realesrgan_models(None):
        listing.append({"name": model.name, "path": model.data_path, "scale": model.scale})
    return listing
2022-11-05 01:43:02 +08:00
2022-12-15 10:01:32 +08:00
def get_prompt_styles(self):
    """List every saved prompt style with its name, prompt and negative prompt."""
    styles = shared.prompt_styles.styles

    def as_entry(style):
        # Fields are read positionally: 0 = name, 1 = prompt, 2 = negative prompt.
        return {"name": style[0], "prompt": style[1], "negative_prompt": style[2]}

    return [as_entry(styles[key]) for key in styles]
2023-01-02 07:17:33 +08:00
def get_embeddings(self):
    """Describe the embeddings held by the hijack's embedding database.

    Returns:
        ``{"loaded": ..., "skipped": ...}``, built from ``word_embeddings``
        and ``skipped_embeddings`` respectively. Each maps an embedding's
        name to its step, sd_checkpoint, sd_checkpoint_name, shape and
        vectors attributes.
    """
    db = sd_hijack.model_hijack.embedding_db
    exported_fields = ("step", "sd_checkpoint", "sd_checkpoint_name", "shape", "vectors")

    def summarize(collection):
        summary = {}
        for emb in collection.values():
            key = emb.name
            summary[key] = {field: getattr(emb, field) for field in exported_fields}
        return summary

    return {
        "loaded": summarize(db.word_embeddings),
        "skipped": summarize(db.skipped_embeddings),
    }
2024-01-21 21:05:47 +08:00
def refresh_embeddings(self):
    """Force-reload textual-inversion embeddings while holding the queue lock."""
    embedding_db = sd_hijack.model_hijack.embedding_db
    with self.queue_lock:
        embedding_db.load_textual_inversion_embeddings(force_reload=True)
2022-12-12 03:16:44 +08:00
def refresh_checkpoints(self):
    """Call ``shared.refresh_checkpoints()`` while holding the queue lock."""
    lock = self.queue_lock
    with lock:
        shared.refresh_checkpoints()
2022-10-30 18:08:40 +08:00
2023-07-24 19:45:08 +08:00
def refresh_vae(self):
    """Call ``shared_items.refresh_vae_list()`` while holding the queue lock."""
    lock = self.queue_lock
    with lock:
        shared_items.refresh_vae_list()
2022-12-25 07:02:22 +08:00
def create_embedding(self, args: dict):
    """Create an empty textual-inversion embedding and reload the embedding database.

    ``args`` is forwarded as keyword arguments to the module-level
    ``create_embedding`` function (not this method). ``shared.state`` is
    begun/ended around the operation.

    Returns:
        CreateResponse whose ``info`` holds the created filename, or the error
        text if creation raised AssertionError. (The error path previously
        returned a TrainResponse, inconsistent with this create endpoint.)
        Other exceptions propagate.
    """
    try:
        shared.state.begin(job="create_embedding")
        filename = create_embedding(**args)  # create empty embedding
        # Reload embeddings so the new one can be used immediately.
        sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
        return models.CreateResponse(info=f"create embedding filename: {filename}")
    except AssertionError as e:
        return models.CreateResponse(info=f"create embedding error: {e}")
    finally:
        shared.state.end()
2022-12-25 07:02:22 +08:00
def create_hypernetwork(self, args: dict):
    """Create an empty hypernetwork.

    ``args`` is forwarded as keyword arguments to the module-level
    ``create_hypernetwork`` function (not this method). ``shared.state`` is
    begun/ended around the operation.

    Returns:
        CreateResponse whose ``info`` holds the created filename, or the error
        text if creation raised AssertionError. (The error path previously
        returned a TrainResponse, inconsistent with this create endpoint.)
        Other exceptions propagate.
    """
    try:
        shared.state.begin(job="create_hypernetwork")
        filename = create_hypernetwork(**args)  # create empty hypernetwork
        return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
    except AssertionError as e:
        return models.CreateResponse(info=f"create hypernetwork error: {e}")
    finally:
        shared.state.end()
2022-12-25 07:02:22 +08:00
def train_embedding(self, args: dict):
    """Train a textual-inversion embedding and report the outcome as text.

    ``args`` is forwarded as keyword arguments to the module-level
    ``train_embedding`` function (not this method). Unless
    ``shared.opts.training_xattention_optimizations`` is enabled, cross-
    attention optimizations are undone for the duration of training and
    re-applied afterwards. Any exception from the training call is caught and
    included in the "train embedding complete" message. Any other failure
    yields a "train embedding error" message. ``shared.state`` is
    begun/ended around everything.
    """
    try:
        shared.state.begin(job="train_embedding")
        keep_xattention = shared.opts.training_xattention_optimizations
        failure = None
        trained_file = ''
        if not keep_xattention:
            sd_hijack.undo_optimizations()
        try:
            # The first element of the returned pair (the embedding) is not needed.
            _, trained_file = train_embedding(**args)  # can take a long time to complete
        except Exception as train_err:
            failure = train_err
        finally:
            if not keep_xattention:
                sd_hijack.apply_optimizations()
        return models.TrainResponse(info=f"train embedding complete: filename: {trained_file} error: {failure}")
    except Exception as msg:
        return models.TrainResponse(info=f"train embedding error: {msg}")
    finally:
        shared.state.end()
2022-12-25 07:02:22 +08:00
def train_hypernetwork(self, args: dict):
    """Train a hypernetwork and report the outcome as text.

    ``args`` is forwarded as keyword arguments to the module-level
    ``train_hypernetwork`` function (not this method).
    ``shared.loaded_hypernetworks`` is cleared first. Unless
    ``shared.opts.training_xattention_optimizations`` is enabled, cross-
    attention optimizations are undone during training and re-applied
    afterwards. The model's ``cond_stage_model`` and ``first_stage_model``
    are moved back to ``devices.device`` whether or not training succeeded.
    An exception from the training call is included in the
    "train hypernetwork complete" message; any other failure yields a
    "train hypernetwork error" message.

    Fixes:
        - Messages previously said "train embedding" (copy-paste from
          train_embedding).
        - ``shared.state.end()`` was called twice (inner and outer finally);
          it is now called once.
    """
    try:
        shared.state.begin(job="train_hypernetwork")
        shared.loaded_hypernetworks = []
        apply_optimizations = shared.opts.training_xattention_optimizations
        error = None
        filename = ''
        if not apply_optimizations:
            sd_hijack.undo_optimizations()
        try:
            # The first element of the returned pair (the hypernetwork) is not needed.
            _, filename = train_hypernetwork(**args)
        except Exception as e:
            error = e
        finally:
            shared.sd_model.cond_stage_model.to(devices.device)
            shared.sd_model.first_stage_model.to(devices.device)
            if not apply_optimizations:
                sd_hijack.apply_optimizations()
        return models.TrainResponse(info=f"train hypernetwork complete: filename: {filename} error: {error}")
    except Exception as exc:
        return models.TrainResponse(info=f"train hypernetwork error: {exc}")
    finally:
        # Single end() for the begin() above; runs on every exit path.
        shared.state.end()
2023-01-07 20:51:35 +08:00
def get_memory(self):
    """Report process RAM usage and, when available, CUDA memory statistics.

    The RAM and CUDA sections are gathered independently. A failure in either
    is reported as ``{'error': '<message>'}`` for that section only, and a
    missing CUDA device yields ``{'error': 'unavailable'}``.
    """
    def ram_usage():
        try:
            import os
            import psutil
            proc = psutil.Process(os.getpid())
            # Only rss is guaranteed cross-platform, so no other memory_info() field is used.
            rss = proc.memory_info().rss
            # Total memory is derived from rss and its percentage because reading it
            # directly is not cross-platform safe.
            total = 100 * rss / proc.memory_percent()
            return {'free': total - rss, 'used': rss, 'total': total}
        except Exception as err:
            return {'error': f'{err}'}

    def cuda_usage():
        try:
            import torch
            if not torch.cuda.is_available():
                return {'error': 'unavailable'}

            device_mem = torch.cuda.mem_get_info()
            system = {'free': device_mem[0], 'used': device_mem[1] - device_mem[0], 'total': device_mem[1]}

            stats = dict(torch.cuda.memory_stats(shared.device))

            def current_and_peak(current_key, peak_key):
                return {'current': stats[current_key], 'peak': stats[peak_key]}

            allocated = current_and_peak('allocated_bytes.all.current', 'allocated_bytes.all.peak')
            reserved = current_and_peak('reserved_bytes.all.current', 'reserved_bytes.all.peak')
            active = current_and_peak('active_bytes.all.current', 'active_bytes.all.peak')
            inactive = current_and_peak('inactive_split_bytes.all.current', 'inactive_split_bytes.all.peak')
            events = {'retries': stats['num_alloc_retries'], 'oom': stats['num_ooms']}

            return {
                'system': system,
                'active': active,
                'allocated': allocated,
                'reserved': reserved,
                'inactive': inactive,
                'events': events,
            }
        except Exception as err:
            return {'error': f'{err}'}

    return models.MemoryResponse(ram=ram_usage(), cuda=cuda_usage())
2023-08-25 22:23:17 +08:00
2023-08-25 22:15:35 +08:00
def get_extensions_list(self):
    """Re-scan installed extensions and describe those that have a git remote.

    ``read_info_from_repo`` is called on every extension; extensions whose
    ``remote`` is ``None`` are left out of the result.
    """
    from modules import extensions

    extensions.list_extensions()

    summaries = []
    for extension in extensions.extensions:
        extension.read_info_from_repo()
        if extension.remote is None:
            continue
        summaries.append({
            "name": extension.name,
            "remote": extension.remote,
            "branch": extension.branch,
            "commit_hash": extension.commit_hash,
            "commit_date": extension.commit_date,
            "version": extension.version,
            "enabled": extension.enabled,
        })
    return summaries
2023-01-07 20:51:35 +08:00
2023-07-25 20:19:10 +08:00
def launch(self, server_name, port, root_path):
    """Mount the API router on the app and run it with ``uvicorn.run``.

    Keep-alive timeout and TLS key/cert files come from ``shared.cmd_opts``.
    """
    self.app.include_router(self.router)

    cli = shared.cmd_opts
    uvicorn.run(
        self.app,
        host=server_name,
        port=port,
        root_path=root_path,
        timeout_keep_alive=cli.timeout_keep_alive,
        ssl_keyfile=cli.tls_keyfile,
        ssl_certfile=cli.tls_certfile,
    )
2023-06-10 22:36:34 +08:00
2023-06-14 17:51:47 +08:00
def kill_webui(self):
    """Delegate to ``restart.stop_program()``; returns ``None``."""
    restart.stop_program()
    return None
def restart_webui(self):
    """Restart the program when ``restart.is_restartable()`` allows it.

    If restarting is not possible, or ``restart.restart_program()`` returns
    control here, a 501 (Not Implemented) response is returned.
    """
    restartable = restart.is_restartable()
    if restartable:
        restart.restart_program()
    not_implemented = Response(status_code=501)
    return not_implemented
2023-06-12 17:15:27 +08:00
2023-06-14 18:53:08 +08:00
def stop_webui(self):
    """Request a server stop by setting ``shared.state.server_command`` to ``"stop"``.

    The only parameter was previously named ``request``, but as a method it
    always receives the instance, so it is named ``self``. The bound method
    still takes no arguments, so the endpoint signature is unchanged.
    Presumably the main loop polls ``server_command`` and performs the
    shutdown — confirm against the caller.
    """
    shared.state.server_command = "stop"
    return Response("Stopping.")
2023-07-13 20:21:39 +08:00