stable-diffusion-webui/modules/api/api.py

191 lines
6.3 KiB
Python
Raw Normal View History

2022-10-26 22:33:45 +08:00
import time
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
2022-10-19 13:19:01 +08:00
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
2022-10-26 22:33:45 +08:00
from modules import devices
import uvicorn
2022-10-19 03:04:56 +08:00
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
from typing import List
import json
import io
import base64
from PIL import Image
2022-10-19 13:19:01 +08:00
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
2022-10-19 03:04:56 +08:00
class TextToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class ImageToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
2022-10-26 22:33:45 +08:00
class ProgressResponse(BaseModel):
progress: float
eta_relative: float
state: Json
# copy from wrap_gradio_gpu_call of webui.py
# because queue lock will be acquired in api handlers
# and time start needs to be set
# the function has been modified into two parts
def before_gpu_call():
devices.torch_gc()
shared.state.sampling_step = 0
shared.state.job_count = -1
shared.state.job_no = 0
shared.state.job_timestamp = shared.state.get_job_timestamp()
shared.state.current_latent = None
shared.state.current_image = None
shared.state.current_image_sampling_step = 0
shared.state.skipped = False
shared.state.interrupted = False
shared.state.textinfo = None
shared.state.time_start = time.time()
def after_gpu_call():
shared.state.job = ""
shared.state.job_count = 0
devices.torch_gc()
class Api:
2022-10-18 14:51:53 +08:00
def __init__(self, app, queue_lock):
self.router = APIRouter()
2022-10-18 14:51:53 +08:00
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
2022-10-26 22:33:45 +08:00
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
2022-10-23 05:10:28 +08:00
def __base64_to_image(self, base64_string):
# if has a comma, deal with prefix
if "," in base64_string:
base64_string = base64_string.split(",")[1]
imgdata = base64.b64decode(base64_string)
# convert base64 to PIL image
return Image.open(io.BytesIO(imgdata))
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
2022-10-19 03:04:56 +08:00
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
2022-10-26 22:33:45 +08:00
2022-10-19 03:04:56 +08:00
if sampler_index is None:
2022-10-26 22:33:45 +08:00
raise HTTPException(status_code=404, detail="Sampler not found")
2022-10-18 03:10:36 +08:00
populate = txt2imgreq.copy(update={ # Override __init__ params
2022-10-26 22:33:45 +08:00
"sd_model": shared.sd_model,
2022-10-19 03:04:56 +08:00
"sampler_index": sampler_index[0],
2022-10-18 04:36:14 +08:00
"do_not_save_samples": True,
"do_not_save_grid": True
2022-10-18 03:10:36 +08:00
}
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
2022-10-26 22:33:45 +08:00
before_gpu_call()
2022-10-18 14:51:53 +08:00
with self.queue_lock:
processed = process_images(p)
2022-10-26 22:33:45 +08:00
after_gpu_call()
b64images = []
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
2022-10-26 22:33:45 +08:00
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
2022-10-26 22:33:45 +08:00
if sampler_index is None:
2022-10-26 22:33:45 +08:00
raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images
if init_images is None:
2022-10-26 22:33:45 +08:00
raise HTTPException(status_code=404, detail="Init image not found")
2022-10-23 03:42:00 +08:00
mask = img2imgreq.mask
if mask:
2022-10-23 05:10:28 +08:00
mask = self.__base64_to_image(mask)
2022-10-23 03:42:00 +08:00
2022-10-26 22:33:45 +08:00
populate = img2imgreq.copy(update={ # Override __init__ params
2022-10-26 22:33:45 +08:00
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
2022-10-26 22:33:45 +08:00
"do_not_save_grid": True,
2022-10-23 05:10:28 +08:00
"mask": mask
}
)
p = StableDiffusionProcessingImg2Img(**vars(populate))
imgs = []
for img in init_images:
2022-10-23 05:10:28 +08:00
img = self.__base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs
# Override object param
2022-10-26 22:33:45 +08:00
before_gpu_call()
with self.queue_lock:
processed = process_images(p)
2022-10-26 22:33:45 +08:00
after_gpu_call()
b64images = []
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
2022-10-26 22:33:45 +08:00
def progressapi(self):
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
# avoid dividing zero
progress = 0.01
if shared.state.job_count > 0:
progress += shared.state.job_no / shared.state.job_count
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
progress = min(progress, 1)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
2022-10-19 13:19:01 +08:00
def extrasapi(self):
raise NotImplementedError
2022-10-19 13:19:01 +08:00
def pnginfoapi(self):
raise NotImplementedError
def launch(self, server_name, port):
2022-10-18 14:51:53 +08:00
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)