mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-01 20:35:06 +08:00
commit
fac92610d2
@ -82,8 +82,8 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||||||
- Use VAEs
|
- Use VAEs
|
||||||
- Estimated completion time in progress bar
|
- Estimated completion time in progress bar
|
||||||
- API
|
- API
|
||||||
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
|
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
|
||||||
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
|
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
|
||||||
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
|
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
|
||||||
|
|
||||||
## Installation and Running
|
## Installation and Running
|
||||||
|
@ -9,7 +9,7 @@ contextMenuInit = function(){
|
|||||||
|
|
||||||
function showContextMenu(event,element,menuEntries){
|
function showContextMenu(event,element,menuEntries){
|
||||||
let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
|
let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
|
||||||
let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
|
let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
|
||||||
|
|
||||||
let oldMenu = gradioApp().querySelector('#context-menu')
|
let oldMenu = gradioApp().querySelector('#context-menu')
|
||||||
if(oldMenu){
|
if(oldMenu){
|
||||||
@ -61,15 +61,15 @@ contextMenuInit = function(){
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){
|
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
|
||||||
|
|
||||||
currentItems = menuSpecs.get(targetEmementSelector)
|
currentItems = menuSpecs.get(targetElementSelector)
|
||||||
|
|
||||||
if(!currentItems){
|
if(!currentItems){
|
||||||
currentItems = []
|
currentItems = []
|
||||||
menuSpecs.set(targetEmementSelector,currentItems);
|
menuSpecs.set(targetElementSelector,currentItems);
|
||||||
}
|
}
|
||||||
let newItem = {'id':targetEmementSelector+'_'+uid(),
|
let newItem = {'id':targetElementSelector+'_'+uid(),
|
||||||
'name':entryName,
|
'name':entryName,
|
||||||
'func':entryFunction,
|
'func':entryFunction,
|
||||||
'isNew':true}
|
'isNew':true}
|
||||||
@ -97,7 +97,7 @@ contextMenuInit = function(){
|
|||||||
if(source.id && source.id.indexOf('check_progress')>-1){
|
if(source.id && source.id.indexOf('check_progress')>-1){
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
let oldMenu = gradioApp().querySelector('#context-menu')
|
let oldMenu = gradioApp().querySelector('#context-menu')
|
||||||
if(oldMenu){
|
if(oldMenu){
|
||||||
oldMenu.remove()
|
oldMenu.remove()
|
||||||
@ -117,7 +117,7 @@ contextMenuInit = function(){
|
|||||||
})
|
})
|
||||||
});
|
});
|
||||||
eventListenerApplied=true
|
eventListenerApplied=true
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
|
return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
|
||||||
@ -152,8 +152,8 @@ addContextMenuEventListener = initResponse[2];
|
|||||||
generateOnRepeat('#img2img_generate','#img2img_interrupt');
|
generateOnRepeat('#img2img_generate','#img2img_interrupt');
|
||||||
})
|
})
|
||||||
|
|
||||||
let cancelGenerateForever = function(){
|
let cancelGenerateForever = function(){
|
||||||
clearInterval(window.generateOnRepeatInterval)
|
clearInterval(window.generateOnRepeatInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
|
appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
|
||||||
@ -162,7 +162,7 @@ addContextMenuEventListener = initResponse[2];
|
|||||||
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
|
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
|
||||||
|
|
||||||
appendContextMenuOption('#roll','Roll three',
|
appendContextMenuOption('#roll','Roll three',
|
||||||
function(){
|
function(){
|
||||||
let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
|
let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
|
||||||
setTimeout(function(){rollbutton.click()},100)
|
setTimeout(function(){rollbutton.click()},100)
|
||||||
setTimeout(function(){rollbutton.click()},200)
|
setTimeout(function(){rollbutton.click()},200)
|
||||||
|
@ -3,7 +3,7 @@ global_progressbars = {}
|
|||||||
galleries = {}
|
galleries = {}
|
||||||
galleryObservers = {}
|
galleryObservers = {}
|
||||||
|
|
||||||
// this tracks laumnches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
|
// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
|
||||||
timeoutIds = {}
|
timeoutIds = {}
|
||||||
|
|
||||||
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
||||||
@ -20,21 +20,21 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
|
|||||||
|
|
||||||
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
|
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
|
||||||
var interrupt = gradioApp().getElementById(id_interrupt)
|
var interrupt = gradioApp().getElementById(id_interrupt)
|
||||||
|
|
||||||
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
|
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
|
||||||
if(progressbar.innerText){
|
if(progressbar.innerText){
|
||||||
let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
|
let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
|
||||||
if(document.title != newtitle){
|
if(document.title != newtitle){
|
||||||
document.title = newtitle;
|
document.title = newtitle;
|
||||||
}
|
}
|
||||||
}else{
|
}else{
|
||||||
let newtitle = 'Stable Diffusion'
|
let newtitle = 'Stable Diffusion'
|
||||||
if(document.title != newtitle){
|
if(document.title != newtitle){
|
||||||
document.title = newtitle;
|
document.title = newtitle;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
|
if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
|
||||||
global_progressbars[id_progressbar] = progressbar
|
global_progressbars[id_progressbar] = progressbar
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
|
|||||||
skip.style.display = "none"
|
skip.style.display = "none"
|
||||||
}
|
}
|
||||||
interrupt.style.display = "none"
|
interrupt.style.display = "none"
|
||||||
|
|
||||||
//disconnect observer once generation finished, so user can close selected image if they want
|
//disconnect observer once generation finished, so user can close selected image if they want
|
||||||
if (galleryObservers[id_gallery]) {
|
if (galleryObservers[id_gallery]) {
|
||||||
galleryObservers[id_gallery].disconnect();
|
galleryObservers[id_gallery].disconnect();
|
||||||
|
@ -100,7 +100,7 @@ function create_submit_args(args){
|
|||||||
|
|
||||||
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
|
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
|
||||||
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
|
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
|
||||||
// I don't know why gradio is seding outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
|
// I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
|
||||||
// If gradio at some point stops sending outputs, this may break something
|
// If gradio at some point stops sending outputs, this may break something
|
||||||
if(Array.isArray(res[res.length - 3])){
|
if(Array.isArray(res[res.length - 3])){
|
||||||
res[res.length - 3] = null
|
res[res.length - 3] = null
|
||||||
|
@ -67,10 +67,10 @@ def encode_pil_to_base64(image):
|
|||||||
class Api:
|
class Api:
|
||||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||||
if shared.cmd_opts.api_auth:
|
if shared.cmd_opts.api_auth:
|
||||||
self.credenticals = dict()
|
self.credentials = dict()
|
||||||
for auth in shared.cmd_opts.api_auth.split(","):
|
for auth in shared.cmd_opts.api_auth.split(","):
|
||||||
user, password = auth.split(":")
|
user, password = auth.split(":")
|
||||||
self.credenticals[user] = password
|
self.credentials[user] = password
|
||||||
|
|
||||||
self.router = APIRouter()
|
self.router = APIRouter()
|
||||||
self.app = app
|
self.app = app
|
||||||
@ -93,7 +93,7 @@ class Api:
|
|||||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
||||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
||||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
||||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
||||||
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
||||||
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||||
|
|
||||||
@ -102,9 +102,9 @@ class Api:
|
|||||||
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
|
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
|
||||||
return self.app.add_api_route(path, endpoint, **kwargs)
|
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||||
|
|
||||||
def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())):
|
def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
|
||||||
if credenticals.username in self.credenticals:
|
if credentials.username in self.credentials:
|
||||||
if compare_digest(credenticals.password, self.credenticals[credenticals.username]):
|
if compare_digest(credentials.password, self.credentials[credentials.username]):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
|
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
|
||||||
@ -239,7 +239,7 @@ class Api:
|
|||||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
||||||
image_b64 = interrogatereq.image
|
image_b64 = interrogatereq.image
|
||||||
if image_b64 is None:
|
if image_b64 is None:
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
|
|
||||||
img = decode_base64_to_image(image_b64)
|
img = decode_base64_to_image(image_b64)
|
||||||
img = img.convert('RGB')
|
img = img.convert('RGB')
|
||||||
@ -252,7 +252,7 @@ class Api:
|
|||||||
processed = deepbooru.model.tag(img)
|
processed = deepbooru.model.tag(img)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=404, detail="Model not found")
|
raise HTTPException(status_code=404, detail="Model not found")
|
||||||
|
|
||||||
return InterrogateResponse(caption=processed)
|
return InterrogateResponse(caption=processed)
|
||||||
|
|
||||||
def interruptapi(self):
|
def interruptapi(self):
|
||||||
@ -308,7 +308,7 @@ class Api:
|
|||||||
def get_realesrgan_models(self):
|
def get_realesrgan_models(self):
|
||||||
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
|
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
|
||||||
|
|
||||||
def get_promp_styles(self):
|
def get_prompt_styles(self):
|
||||||
styleList = []
|
styleList = []
|
||||||
for k in shared.prompt_styles.styles:
|
for k in shared.prompt_styles.styles:
|
||||||
style = shared.prompt_styles.styles[k]
|
style = shared.prompt_styles.styles[k]
|
||||||
|
@ -128,7 +128,7 @@ class ExtrasBaseRequest(BaseModel):
|
|||||||
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
|
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
|
||||||
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
||||||
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
||||||
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
|
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
|
||||||
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||||
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||||
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
|
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
|
||||||
|
@ -438,7 +438,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
|
The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
|
||||||
basename (`str`):
|
basename (`str`):
|
||||||
The base filename which will be applied to `filename pattern`.
|
The base filename which will be applied to `filename pattern`.
|
||||||
seed, prompt, short_filename,
|
seed, prompt, short_filename,
|
||||||
extension (`str`):
|
extension (`str`):
|
||||||
Image file extension, default is `png`.
|
Image file extension, default is `png`.
|
||||||
pngsectionname (`str`):
|
pngsectionname (`str`):
|
||||||
@ -599,7 +599,7 @@ def read_info_from_image(image):
|
|||||||
Negative prompt: {json_info["uc"]}
|
Negative prompt: {json_info["uc"]}
|
||||||
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error parsing NovelAI iamge generation parameters:", file=sys.stderr)
|
print(f"Error parsing NovelAI image generation parameters:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
return geninfo, items
|
return geninfo, items
|
||||||
|
@ -150,11 +150,11 @@ class StableDiffusionProcessing():
|
|||||||
|
|
||||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||||
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
||||||
|
|
||||||
# Add the fake full 1s mask to the first dimension.
|
# Add the fake full 1s mask to the first dimension.
|
||||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||||
image_conditioning = image_conditioning.to(x.dtype)
|
image_conditioning = image_conditioning.to(x.dtype)
|
||||||
|
|
||||||
return image_conditioning
|
return image_conditioning
|
||||||
|
|
||||||
@ -202,7 +202,7 @@ class StableDiffusionProcessing():
|
|||||||
source_image * (1.0 - conditioning_mask),
|
source_image * (1.0 - conditioning_mask),
|
||||||
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
|
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Encode the new masked image using first stage of network.
|
# Encode the new masked image using first stage of network.
|
||||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
||||||
|
|
||||||
@ -540,7 +540,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
for n in range(p.n_iter):
|
for n in range(p.n_iter):
|
||||||
if state.skipped:
|
if state.skipped:
|
||||||
state.skipped = False
|
state.skipped = False
|
||||||
|
|
||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -615,7 +615,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
image.info["parameters"] = text
|
image.info["parameters"] = text
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
|
|
||||||
del x_samples_ddim
|
del x_samples_ddim
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
@ -707,7 +707,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
||||||
|
|
||||||
"""saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
|
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
|
||||||
def save_intermediate(image, index):
|
def save_intermediate(image, index):
|
||||||
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
|
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
|
||||||
return
|
return
|
||||||
@ -723,7 +723,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||||
|
|
||||||
# Avoid making the inpainting conditioning unless necessary as
|
# Avoid making the inpainting conditioning unless necessary as
|
||||||
# this does need some extra compute to decode / encode the image again.
|
# this does need some extra compute to decode / encode the image again.
|
||||||
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
|
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
|
||||||
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
|
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
|
||||||
|
@ -80,7 +80,7 @@ def check_pt(filename, extra_handler):
|
|||||||
# new pytorch format is a zip file
|
# new pytorch format is a zip file
|
||||||
with zipfile.ZipFile(filename) as z:
|
with zipfile.ZipFile(filename) as z:
|
||||||
check_zip_filenames(filename, z.namelist())
|
check_zip_filenames(filename, z.namelist())
|
||||||
|
|
||||||
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
||||||
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
||||||
if len(data_pkl_filenames) == 0:
|
if len(data_pkl_filenames) == 0:
|
||||||
@ -108,7 +108,7 @@ def load(filename, *args, **kwargs):
|
|||||||
|
|
||||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
this functon is intended to be used by extensions that want to load models with
|
this function is intended to be used by extensions that want to load models with
|
||||||
some extra classes in them that the usual unpickler would find suspicious.
|
some extra classes in them that the usual unpickler would find suspicious.
|
||||||
|
|
||||||
Use the extra_handler argument to specify a function that takes module and field name as text,
|
Use the extra_handler argument to specify a function that takes module and field name as text,
|
||||||
|
@ -36,7 +36,7 @@ class Script:
|
|||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
|
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||||
The return value should be an array of all components that are used in processing.
|
The return value should be an array of all components that are used in processing.
|
||||||
Values of those returned componenbts will be passed to run() and process() functions.
|
Values of those returned components will be passed to run() and process() functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
@ -47,7 +47,7 @@ class Script:
|
|||||||
|
|
||||||
This function should return:
|
This function should return:
|
||||||
- False if the script should not be shown in UI at all
|
- False if the script should not be shown in UI at all
|
||||||
- True if the script should be shown in UI if it's scelected in the scripts drowpdown
|
- True if the script should be shown in UI if it's selected in the scripts dropdown
|
||||||
- script.AlwaysVisible if the script should be shown in UI at all times
|
- script.AlwaysVisible if the script should be shown in UI at all times
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -209,7 +209,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||||||
else:
|
else:
|
||||||
x_in = torch.cat([x] * 2)
|
x_in = torch.cat([x] * 2)
|
||||||
t_in = torch.cat([t] * 2)
|
t_in = torch.cat([t] * 2)
|
||||||
|
|
||||||
if isinstance(c, dict):
|
if isinstance(c, dict):
|
||||||
assert isinstance(unconditional_conditioning, dict)
|
assert isinstance(unconditional_conditioning, dict)
|
||||||
c_in = dict()
|
c_in = dict()
|
||||||
@ -278,7 +278,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||||
|
|
||||||
return x_prev, pred_x0, e_t
|
return x_prev, pred_x0, e_t
|
||||||
|
|
||||||
# =================================================================================================
|
# =================================================================================================
|
||||||
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
|
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
|
||||||
# Adapted from:
|
# Adapted from:
|
||||||
@ -325,7 +325,7 @@ def should_hijack_inpainting(checkpoint_info):
|
|||||||
def do_inpainting_hijack():
|
def do_inpainting_hijack():
|
||||||
# most of this stuff seems to no longer be needed because it is already included into SD2.0
|
# most of this stuff seems to no longer be needed because it is already included into SD2.0
|
||||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||||
# this file should be cleaned up later if weverything tuens out to work fine
|
# this file should be cleaned up later if everything turns out to work fine
|
||||||
|
|
||||||
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
||||||
# ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
# ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
||||||
|
@ -4,7 +4,7 @@ import torch
|
|||||||
class TorchHijackForUnet:
|
class TorchHijackForUnet:
|
||||||
"""
|
"""
|
||||||
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
||||||
this makes it possible to create pictures with dimensions that are muliples of 8 rather than 64
|
this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
|
@ -28,9 +28,9 @@ class DatasetEntry:
|
|||||||
|
|
||||||
|
|
||||||
class PersonalizedBase(Dataset):
|
class PersonalizedBase(Dataset):
|
||||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
|
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
|
||||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
||||||
|
|
||||||
self.placeholder_token = placeholder_token
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
self.width = width
|
self.width = width
|
||||||
@ -50,14 +50,14 @@ class PersonalizedBase(Dataset):
|
|||||||
|
|
||||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
|
|
||||||
|
|
||||||
self.shuffle_tags = shuffle_tags
|
self.shuffle_tags = shuffle_tags
|
||||||
self.tag_drop_out = tag_drop_out
|
self.tag_drop_out = tag_drop_out
|
||||||
|
|
||||||
print("Preparing dataset...")
|
print("Preparing dataset...")
|
||||||
for path in tqdm.tqdm(self.image_paths):
|
for path in tqdm.tqdm(self.image_paths):
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
raise Exception("inturrupted")
|
raise Exception("interrupted")
|
||||||
try:
|
try:
|
||||||
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -144,7 +144,7 @@ class PersonalizedDataLoader(DataLoader):
|
|||||||
self.collate_fn = collate_wrapper_random
|
self.collate_fn = collate_wrapper_random
|
||||||
else:
|
else:
|
||||||
self.collate_fn = collate_wrapper
|
self.collate_fn = collate_wrapper
|
||||||
|
|
||||||
|
|
||||||
class BatchLoader:
|
class BatchLoader:
|
||||||
def __init__(self, data):
|
def __init__(self, data):
|
||||||
|
@ -133,7 +133,7 @@ class EmbeddingDatabase:
|
|||||||
|
|
||||||
process_file(fullfn, fn)
|
process_file(fullfn, fn)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
print(f"Error loading embedding {fn}:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -194,7 +194,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
|
|||||||
csv_writer.writeheader()
|
csv_writer.writeheader()
|
||||||
|
|
||||||
epoch = (step - 1) // epoch_len
|
epoch = (step - 1) // epoch_len
|
||||||
epoch_step = (step - 1) % epoch_len
|
epoch_step = (step - 1) % epoch_len
|
||||||
|
|
||||||
csv_writer.writerow({
|
csv_writer.writerow({
|
||||||
"step": step,
|
"step": step,
|
||||||
@ -270,9 +270,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||||
|
|
||||||
pin_memory = shared.opts.pin_memory
|
pin_memory = shared.opts.pin_memory
|
||||||
|
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
|
||||||
|
|
||||||
latent_sampling_method = ds.latent_sampling_method
|
latent_sampling_method = ds.latent_sampling_method
|
||||||
@ -295,12 +295,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||||||
loss_step = 0
|
loss_step = 0
|
||||||
_loss_step = 0 #internal
|
_loss_step = 0 #internal
|
||||||
|
|
||||||
|
|
||||||
last_saved_file = "<none>"
|
last_saved_file = "<none>"
|
||||||
last_saved_image = "<none>"
|
last_saved_image = "<none>"
|
||||||
forced_filename = "<none>"
|
forced_filename = "<none>"
|
||||||
embedding_yet_to_be_embedded = False
|
embedding_yet_to_be_embedded = False
|
||||||
|
|
||||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||||
try:
|
try:
|
||||||
for i in range((steps-initial_step) * gradient_step):
|
for i in range((steps-initial_step) * gradient_step):
|
||||||
@ -327,10 +327,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||||||
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
c = shared.sd_model.cond_stage_model(batch.cond_text)
|
||||||
loss = shared.sd_model(x, c)[0] / gradient_step
|
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||||
del x
|
del x
|
||||||
|
|
||||||
_loss_step += loss.item()
|
_loss_step += loss.item()
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
# go back until we reach gradient accumulation steps
|
# go back until we reach gradient accumulation steps
|
||||||
if (j + 1) % gradient_step != 0:
|
if (j + 1) % gradient_step != 0:
|
||||||
continue
|
continue
|
||||||
|
@ -18,7 +18,7 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
|
|||||||
ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
|
ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
|
||||||
hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
|
hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
|
||||||
|
|
||||||
first_pocessed = None
|
first_processed = None
|
||||||
|
|
||||||
state.job_count = len(xs) * len(ys)
|
state.job_count = len(xs) * len(ys)
|
||||||
|
|
||||||
@ -27,17 +27,17 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
|
|||||||
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
|
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
|
||||||
|
|
||||||
processed = cell(x, y)
|
processed = cell(x, y)
|
||||||
if first_pocessed is None:
|
if first_processed is None:
|
||||||
first_pocessed = processed
|
first_processed = processed
|
||||||
|
|
||||||
res.append(processed.images[0])
|
res.append(processed.images[0])
|
||||||
|
|
||||||
grid = images.image_grid(res, rows=len(ys))
|
grid = images.image_grid(res, rows=len(ys))
|
||||||
grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
|
grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
|
||||||
|
|
||||||
first_pocessed.images = [grid]
|
first_processed.images = [grid]
|
||||||
|
|
||||||
return first_pocessed
|
return first_processed
|
||||||
|
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
|
4
webui.py
4
webui.py
@ -154,8 +154,8 @@ def webui():
|
|||||||
|
|
||||||
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
|
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
|
||||||
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
|
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
|
||||||
# running web ui and do whatever the attcker wants, including installing an extension and
|
# running web ui and do whatever the attacker wants, including installing an extension and
|
||||||
# runnnig its code. We disable this here. Suggested by RyotaK.
|
# running its code. We disable this here. Suggested by RyotaK.
|
||||||
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
|
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
|
||||||
|
|
||||||
setup_cors(app)
|
setup_cors(app)
|
||||||
|
Loading…
Reference in New Issue
Block a user