mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-21 13:50:12 +08:00
Merge branch 'master' into master
This commit is contained in:
commit
37aafdb059
32
.github/ISSUE_TEMPLATE/bug_report.md
vendored
32
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@ -1,32 +0,0 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Create a report to help us improve
|
||||
title: ''
|
||||
labels: bug-report
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Scroll down to '....'
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Desktop (please complete the following information):**
|
||||
- OS: [e.g. Windows, Linux]
|
||||
- Browser [e.g. chrome, safari]
|
||||
- Commit revision [looks like this: e68484500f76a33ba477d5a99340ab30451e557b; can be seen when launching webui.bat, or obtained manually by running `git rev-parse HEAD`]
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
83
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
83
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@ -0,0 +1,83 @@
|
||||
name: Bug Report
|
||||
description: You think somethings is broken in the UI
|
||||
title: "[Bug]: "
|
||||
labels: ["bug-report"]
|
||||
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for this?
|
||||
description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
|
||||
options:
|
||||
- label: I have searched the existing issues and checked the recent builds/commits
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
*Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
|
||||
- type: textarea
|
||||
id: what-did
|
||||
attributes:
|
||||
label: What happened?
|
||||
description: Tell us what happened in a very clear and simple way
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: steps
|
||||
attributes:
|
||||
label: Steps to reproduce the problem
|
||||
description: Please provide us with precise step by step information on how to reproduce the bug
|
||||
value: |
|
||||
1. Go to ....
|
||||
2. Press ....
|
||||
3. ...
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: what-should
|
||||
attributes:
|
||||
label: What should have happened?
|
||||
description: tell what you think the normal behavior should be
|
||||
validations:
|
||||
required: true
|
||||
- type: input
|
||||
id: commit
|
||||
attributes:
|
||||
label: Commit where the problem happens
|
||||
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
id: platforms
|
||||
attributes:
|
||||
label: What platforms do you use to access UI ?
|
||||
multiple: true
|
||||
options:
|
||||
- Windows
|
||||
- Linux
|
||||
- MacOS
|
||||
- iOS
|
||||
- Android
|
||||
- Other/Cloud
|
||||
- type: dropdown
|
||||
id: browsers
|
||||
attributes:
|
||||
label: What browsers do you use to access the UI ?
|
||||
multiple: true
|
||||
options:
|
||||
- Mozilla Firefox
|
||||
- Google Chrome
|
||||
- Brave
|
||||
- Apple Safari
|
||||
- Microsoft Edge
|
||||
- type: textarea
|
||||
id: cmdargs
|
||||
attributes:
|
||||
label: Command Line Arguments
|
||||
description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
|
||||
render: Shell
|
||||
- type: textarea
|
||||
id: misc
|
||||
attributes:
|
||||
label: Additional information, context and logs
|
||||
description: Please provide us with any relevant additional info, context or log output.
|
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: WebUI Community Support
|
||||
url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
|
||||
about: Please ask and answer questions here.
|
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@ -1,20 +0,0 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: 'suggestion'
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
40
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
40
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@ -0,0 +1,40 @@
|
||||
name: Feature request
|
||||
description: Suggest an idea for this project
|
||||
title: "[Feature Request]: "
|
||||
labels: ["suggestion"]
|
||||
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for this?
|
||||
description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.
|
||||
options:
|
||||
- label: I have searched the existing issues and checked the recent builds/commits
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
*Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*
|
||||
- type: textarea
|
||||
id: feature
|
||||
attributes:
|
||||
label: What would your feature do ?
|
||||
description: Tell us about your feature in a very clear and simple way, and what problem it would solve
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: workflow
|
||||
attributes:
|
||||
label: Proposed workflow
|
||||
description: Please provide us with step by step information on how you'd like the feature to be accessed and used
|
||||
value: |
|
||||
1. Go to ....
|
||||
2. Press ....
|
||||
3. ...
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: misc
|
||||
attributes:
|
||||
label: Additional information
|
||||
description: Add any other context or screenshots about the feature request here.
|
31
.github/workflows/run_tests.yaml
vendored
Normal file
31
.github/workflows/run_tests.yaml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
name: Run basic features tests on CPU with empty SD model
|
||||
|
||||
on:
|
||||
- push
|
||||
- pull_request
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.10.6
|
||||
- uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||
restore-keys: ${{ runner.os }}-pip-
|
||||
- name: Run tests
|
||||
run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
|
||||
- name: Upload main app stdout-stderr
|
||||
uses: actions/upload-artifact@v3
|
||||
if: always()
|
||||
with:
|
||||
name: stdout-stderr
|
||||
path: |
|
||||
test/stdout.txt
|
||||
test/stderr.txt
|
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,5 +1,6 @@
|
||||
__pycache__
|
||||
*.ckpt
|
||||
*.safetensors
|
||||
*.pth
|
||||
/ESRGAN/*
|
||||
/SwinIR/*
|
||||
@ -27,3 +28,7 @@ __pycache__
|
||||
notification.mp3
|
||||
/SwinIR
|
||||
/textual_inversion
|
||||
.vscode
|
||||
/extensions
|
||||
/test/stdout.txt
|
||||
/test/stderr.txt
|
||||
|
11
CODEOWNERS
11
CODEOWNERS
@ -1 +1,12 @@
|
||||
* @AUTOMATIC1111
|
||||
|
||||
# if you were managing a localization and were removed from this file, this is because
|
||||
# the intended way to do localizations now is via extensions. See:
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
|
||||
# Make a repo with your localization and since you are still listed as a collaborator
|
||||
# you can add it to the wiki page yourself. This change is because some people complained
|
||||
# the git commit log is cluttered with things unrelated to almost everyone and
|
||||
# because I believe this is the best overall for the project to handle localizations almost
|
||||
# entirely without my oversight.
|
||||
|
||||
|
||||
|
42
README.md
42
README.md
@ -11,6 +11,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
||||
- One click install and run script (but you still must install python and git)
|
||||
- Outpainting
|
||||
- Inpainting
|
||||
- Color Sketch
|
||||
- Prompt Matrix
|
||||
- Stable Diffusion Upscale
|
||||
- Attention, specify parts of text that the model should pay more attention to
|
||||
@ -23,6 +24,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
||||
- have as many embeddings as you want and use any names you like for them
|
||||
- use multiple embeddings with different numbers of vectors per token
|
||||
- works with half precision floating point numbers
|
||||
- train embeddings on 8GB (also reports of 6GB working)
|
||||
- Extras tab with:
|
||||
- GFPGAN, neural network that fixes faces
|
||||
- CodeFormer, face restoration tool as an alternative to GFPGAN
|
||||
@ -37,14 +39,14 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
||||
- Interrupt processing at any time
|
||||
- 4GB video card support (also reports of 2GB working)
|
||||
- Correct seeds for batches
|
||||
- Prompt length validation
|
||||
- get length of prompt in tokens as you type
|
||||
- get a warning after generation if some text was truncated
|
||||
- Live prompt token length validation
|
||||
- Generation parameters
|
||||
- parameters you used to generate images are saved with that image
|
||||
- in PNG chunks for PNG, in EXIF for JPEG
|
||||
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
|
||||
- can be disabled in settings
|
||||
- drag and drop an image/text-parameters to promptbox
|
||||
- Read Generation Parameters Button, loads parameters in promptbox to UI
|
||||
- Settings page
|
||||
- Running arbitrary python code from UI (must run with --allow-code to enable)
|
||||
- Mouseover hints for most UI elements
|
||||
@ -59,25 +61,37 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
||||
- CLIP interrogator, a button that tries to guess prompt from an image
|
||||
- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
|
||||
- Batch Processing, process a group of files using img2img
|
||||
- Img2img Alternative
|
||||
- Img2img Alternative, reverse Euler method of cross attention control
|
||||
- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
|
||||
- Reloading checkpoints on the fly
|
||||
- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
|
||||
- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one
|
||||
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
|
||||
- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
|
||||
- separate prompts using uppercase `AND`
|
||||
- also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
|
||||
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
|
||||
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
|
||||
- DeepDanbooru integration, creates danbooru style tags for anime prompts
|
||||
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
|
||||
- via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
|
||||
- Generate forever option
|
||||
- Training tab
|
||||
- hypernetworks and embeddings options
|
||||
- Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
|
||||
- Clip skip
|
||||
- Use Hypernetworks
|
||||
- Use VAEs
|
||||
- Estimated completion time in progress bar
|
||||
- API
|
||||
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
|
||||
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
|
||||
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
|
||||
|
||||
## Installation and Running
|
||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||
|
||||
Alternatively, use Google Colab:
|
||||
Alternatively, use online services (like Google Colab):
|
||||
|
||||
- [Colab, maintained by Akaibu](https://colab.research.google.com/drive/1kw3egmSn-KgWsikYvOMjJkVDsPLjEMzl)
|
||||
- [Colab, original by me, outdated](https://colab.research.google.com/drive/1Iy-xW9t1-OQWhb0hNxueGij8phCyluOh).
|
||||
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
||||
|
||||
### Automatic Installation on Windows
|
||||
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
|
||||
@ -113,6 +127,8 @@ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC
|
||||
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
|
||||
|
||||
## Credits
|
||||
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
|
||||
|
||||
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
|
||||
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
|
||||
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
|
||||
@ -121,15 +137,17 @@ The documentation was moved from this README over to the project's [wiki](https:
|
||||
- SwinIR - https://github.com/JingyunLiang/SwinIR
|
||||
- Swin2SR - https://github.com/mv-lab/swin2sr
|
||||
- LDSR - https://github.com/Hafiidz/latent-diffusion
|
||||
- MiDaS - https://github.com/isl-org/MiDaS
|
||||
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
|
||||
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
|
||||
- InvokeAI, lstein - Cross Attention layer optimization - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
|
||||
- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
|
||||
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
|
||||
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
|
||||
- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
|
||||
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
|
||||
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
|
||||
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
|
||||
- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
|
||||
- xformers - https://github.com/facebookresearch/xformers
|
||||
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
|
||||
- Security advice - RyotaK
|
||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||
- (You)
|
||||
|
72
configs/alt-diffusion-inference.yaml
Normal file
72
configs/alt-diffusion-inference.yaml
Normal file
@ -0,0 +1,72 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: modules.xlmr.BertSeriesModelWithTransformation
|
||||
params:
|
||||
name: "XLMR-Large"
|
70
configs/v1-inference.yaml
Normal file
70
configs/v1-inference.yaml
Normal file
@ -0,0 +1,70 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import gc
|
||||
import time
|
||||
import warnings
|
||||
@ -8,27 +9,49 @@ import torchvision
|
||||
from PIL import Image
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import OmegaConf
|
||||
import safetensors.torch
|
||||
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.util import instantiate_from_config, ismap
|
||||
from modules import shared, sd_hijack
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
cached_ldsr_model: torch.nn.Module = None
|
||||
|
||||
|
||||
# Create LDSR Class
|
||||
class LDSR:
|
||||
def load_model_from_config(self, half_attention):
|
||||
global cached_ldsr_model
|
||||
|
||||
if shared.opts.ldsr_cached and cached_ldsr_model is not None:
|
||||
print("Loading model from cache")
|
||||
model: torch.nn.Module = cached_ldsr_model
|
||||
else:
|
||||
print(f"Loading model from {self.modelPath}")
|
||||
_, extension = os.path.splitext(self.modelPath)
|
||||
if extension.lower() == ".safetensors":
|
||||
pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
|
||||
else:
|
||||
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
||||
sd = pl_sd["state_dict"]
|
||||
sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
|
||||
config = OmegaConf.load(self.yamlPath)
|
||||
model = instantiate_from_config(config.model)
|
||||
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
|
||||
model: torch.nn.Module = instantiate_from_config(config.model)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
model.cuda()
|
||||
model = model.to(shared.device)
|
||||
if half_attention:
|
||||
model = model.half()
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model = model.to(memory_format=torch.channels_last)
|
||||
|
||||
sd_hijack.model_hijack.hijack(model) # apply optimization
|
||||
model.eval()
|
||||
|
||||
if shared.opts.ldsr_cached:
|
||||
cached_ldsr_model = model
|
||||
|
||||
return {"model": model}
|
||||
|
||||
def __init__(self, model_path, yaml_path):
|
||||
@ -93,6 +116,7 @@ class LDSR:
|
||||
down_sample_method = 'Lanczos'
|
||||
|
||||
gc.collect()
|
||||
if torch.cuda.is_available:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
im_og = image
|
||||
@ -101,8 +125,8 @@ class LDSR:
|
||||
down_sample_rate = target_scale / 4
|
||||
wd = width_og * down_sample_rate
|
||||
hd = height_og * down_sample_rate
|
||||
width_downsampled_pre = int(wd)
|
||||
height_downsampled_pre = int(hd)
|
||||
width_downsampled_pre = int(np.ceil(wd))
|
||||
height_downsampled_pre = int(np.ceil(hd))
|
||||
|
||||
if down_sample_rate != 1:
|
||||
print(
|
||||
@ -110,7 +134,12 @@ class LDSR:
|
||||
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||
else:
|
||||
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
||||
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
||||
|
||||
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
|
||||
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
|
||||
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
||||
|
||||
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
|
||||
|
||||
sample = logs["sample"]
|
||||
sample = sample.detach().cpu()
|
||||
@ -120,9 +149,14 @@ class LDSR:
|
||||
sample = np.transpose(sample, (0, 2, 3, 1))
|
||||
a = Image.fromarray(sample[0])
|
||||
|
||||
# remove padding
|
||||
a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
if torch.cuda.is_available:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return a
|
||||
|
||||
|
||||
@ -137,7 +171,7 @@ def get_cond(selected_path):
|
||||
c = rearrange(c, '1 c h w -> 1 h w c')
|
||||
c = 2. * c - 1.
|
||||
|
||||
c = c.to(torch.device("cuda"))
|
||||
c = c.to(shared.device)
|
||||
example["LR_image"] = c
|
||||
example["image"] = c_up
|
||||
|
6
extensions-builtin/LDSR/preload.py
Normal file
6
extensions-builtin/LDSR/preload.py
Normal file
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
|
@ -5,8 +5,9 @@ import traceback
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.ldsr_model_arch import LDSR
|
||||
from modules import shared
|
||||
from ldsr_model_arch import LDSR
|
||||
from modules import shared, script_callbacks
|
||||
import sd_hijack_autoencoder, sd_hijack_ddpm_v1
|
||||
|
||||
|
||||
class UpscalerLDSR(Upscaler):
|
||||
@ -24,6 +25,7 @@ class UpscalerLDSR(Upscaler):
|
||||
yaml_path = os.path.join(self.model_path, "project.yaml")
|
||||
old_model_path = os.path.join(self.model_path, "model.pth")
|
||||
new_model_path = os.path.join(self.model_path, "model.ckpt")
|
||||
safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
|
||||
if os.path.exists(yaml_path):
|
||||
statinfo = os.stat(yaml_path)
|
||||
if statinfo.st_size >= 10485760:
|
||||
@ -32,6 +34,9 @@ class UpscalerLDSR(Upscaler):
|
||||
if os.path.exists(old_model_path):
|
||||
print("Renaming model from model.pth to model.ckpt")
|
||||
os.rename(old_model_path, new_model_path)
|
||||
if os.path.exists(safetensors_model_path):
|
||||
model = safetensors_model_path
|
||||
else:
|
||||
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||
file_name="model.ckpt", progress=True)
|
||||
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
|
||||
@ -52,3 +57,13 @@ class UpscalerLDSR(Upscaler):
|
||||
return img
|
||||
ddim_steps = shared.opts.ldsr_steps
|
||||
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
import gradio as gr
|
||||
|
||||
shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
|
||||
shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
286
extensions-builtin/LDSR/sd_hijack_autoencoder.py
Normal file
286
extensions-builtin/LDSR/sd_hijack_autoencoder.py
Normal file
@ -0,0 +1,286 @@
|
||||
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
||||
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
||||
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
||||
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
import ldm.models.autoencoder
|
||||
|
||||
class VQModel(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.n_embed = n_embed
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
self.batch_resize_range = batch_resize_range
|
||||
if self.batch_resize_range is not None:
|
||||
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lr_g_factor = lr_g_factor
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
return quant, emb_loss, info
|
||||
|
||||
def encode_to_prequant(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, quant):
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
def decode_code(self, code_b):
|
||||
quant_b = self.quantize.embed_code(code_b)
|
||||
dec = self.decode(quant_b)
|
||||
return dec
|
||||
|
||||
def forward(self, input, return_pred_indices=False):
|
||||
quant, diff, (_,_,ind) = self.encode(input)
|
||||
dec = self.decode(quant)
|
||||
if return_pred_indices:
|
||||
return dec, diff, ind
|
||||
return dec, diff
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
if self.batch_resize_range is not None:
|
||||
lower_size = self.batch_resize_range[0]
|
||||
upper_size = self.batch_resize_range[1]
|
||||
if self.global_step <= 4:
|
||||
# do the first few batches with max size to avoid later oom
|
||||
new_resize = upper_size
|
||||
else:
|
||||
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
||||
if new_resize != x.shape[2]:
|
||||
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||
x = x.detach()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
# https://github.com/pytorch/pytorch/issues/37142
|
||||
# try not to fool the heuristics
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train",
|
||||
predicted_indices=ind)
|
||||
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log(f"val{suffix}/rec_loss", rec_loss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
self.log(f"val{suffix}/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
del log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr_d = self.learning_rate
|
||||
lr_g = self.lr_g_factor*self.learning_rate
|
||||
print("lr_d", lr_d)
|
||||
print("lr_g", lr_g)
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quantize.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr_d, betas=(0.5, 0.9))
|
||||
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
{
|
||||
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
]
|
||||
return [opt_ae, opt_disc], scheduler
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if only_inputs:
|
||||
log["inputs"] = x
|
||||
return log
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
if plot_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, _ = self(x)
|
||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class VQModelInterface(VQModel):
|
||||
def __init__(self, embed_dim, *args, **kwargs):
|
||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, h, force_not_quantize=False):
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
else:
|
||||
quant = h
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
setattr(ldm.models.autoencoder, "VQModel", VQModel)
|
||||
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
|
1449
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
Normal file
1449
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
Normal file
File diff suppressed because it is too large
Load Diff
6
extensions-builtin/ScuNET/preload.py
Normal file
6
extensions-builtin/ScuNET/preload.py
Normal file
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
|
@ -9,7 +9,7 @@ from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.upscaler
|
||||
from modules import devices, modelloader
|
||||
from modules.scunet_model_arch import SCUNet as net
|
||||
from scunet_model_arch import SCUNet as net
|
||||
|
||||
|
||||
class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
@ -49,14 +49,13 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
if model is None:
|
||||
return img
|
||||
|
||||
device = devices.device_scunet
|
||||
device = devices.get_device_for('scunet')
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(device)
|
||||
|
||||
img = img.to(device)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
@ -67,7 +66,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
return PIL.Image.fromarray(output, 'RGB')
|
||||
|
||||
def load_model(self, path: str):
|
||||
device = devices.device_scunet
|
||||
device = devices.get_device_for('scunet')
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||
progress=True)
|
6
extensions-builtin/SwinIR/preload.py
Normal file
6
extensions-builtin/SwinIR/preload.py
Normal file
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
|
@ -7,15 +7,14 @@ from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import modelloader
|
||||
from modules.shared import cmd_opts, opts, device
|
||||
from modules.swinir_model_arch import SwinIR as net
|
||||
from modules.swinir_model_arch_v2 import Swin2SR as net2
|
||||
from modules import modelloader, devices, script_callbacks, shared
|
||||
from modules.shared import cmd_opts, opts
|
||||
from swinir_model_arch import SwinIR as net
|
||||
from swinir_model_arch_v2 import Swin2SR as net2
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
precision_scope = (
|
||||
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
)
|
||||
|
||||
device_swinir = devices.get_device_for('swinir')
|
||||
|
||||
|
||||
class UpscalerSwinIR(Upscaler):
|
||||
@ -42,7 +41,7 @@ class UpscalerSwinIR(Upscaler):
|
||||
model = self.load_model(model_file)
|
||||
if model is None:
|
||||
return img
|
||||
model = model.to(device)
|
||||
model = model.to(device_swinir, dtype=devices.dtype)
|
||||
img = upscale(img, model)
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
@ -94,25 +93,27 @@ class UpscalerSwinIR(Upscaler):
|
||||
model.load_state_dict(pretrained_model[params], strict=True)
|
||||
else:
|
||||
model.load_state_dict(pretrained_model, strict=True)
|
||||
if not cmd_opts.no_half:
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
|
||||
def upscale(
|
||||
img,
|
||||
model,
|
||||
tile=opts.SWIN_tile,
|
||||
tile_overlap=opts.SWIN_tile_overlap,
|
||||
tile=None,
|
||||
tile_overlap=None,
|
||||
window_size=8,
|
||||
scale=4,
|
||||
):
|
||||
tile = tile or opts.SWIN_tile
|
||||
tile_overlap = tile_overlap or opts.SWIN_tile_overlap
|
||||
|
||||
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(device)
|
||||
with torch.no_grad(), precision_scope("cuda"):
|
||||
img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
|
||||
with torch.no_grad(), devices.autocast():
|
||||
_, _, h_old, w_old = img.size()
|
||||
h_pad = (h_old // window_size + 1) * window_size - h_old
|
||||
w_pad = (w_old // window_size + 1) * window_size - w_old
|
||||
@ -139,8 +140,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||
stride = tile - tile_overlap
|
||||
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
||||
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
||||
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
||||
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
||||
E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
|
||||
W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
|
||||
|
||||
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
||||
for h_idx in h_idx_list:
|
||||
@ -159,3 +160,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||
output = E.div_(W)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
import gradio as gr
|
||||
|
||||
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
|
||||
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
@ -0,0 +1,107 @@
|
||||
// Stable Diffusion WebUI - Bracket checker
|
||||
// Version 1.0
|
||||
// By Hingashi no Florin/Bwin4L
|
||||
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
|
||||
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
||||
|
||||
function checkBrackets(evt) {
|
||||
textArea = evt.target;
|
||||
tabName = evt.target.parentElement.parentElement.id.split("_")[0];
|
||||
counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
|
||||
|
||||
promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
|
||||
|
||||
errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
|
||||
errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
|
||||
errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
|
||||
|
||||
openBracketRegExp = /\(/g;
|
||||
closeBracketRegExp = /\)/g;
|
||||
|
||||
openSquareBracketRegExp = /\[/g;
|
||||
closeSquareBracketRegExp = /\]/g;
|
||||
|
||||
openCurlyBracketRegExp = /\{/g;
|
||||
closeCurlyBracketRegExp = /\}/g;
|
||||
|
||||
totalOpenBracketMatches = 0;
|
||||
totalCloseBracketMatches = 0;
|
||||
totalOpenSquareBracketMatches = 0;
|
||||
totalCloseSquareBracketMatches = 0;
|
||||
totalOpenCurlyBracketMatches = 0;
|
||||
totalCloseCurlyBracketMatches = 0;
|
||||
|
||||
openBracketMatches = textArea.value.match(openBracketRegExp);
|
||||
if(openBracketMatches) {
|
||||
totalOpenBracketMatches = openBracketMatches.length;
|
||||
}
|
||||
|
||||
closeBracketMatches = textArea.value.match(closeBracketRegExp);
|
||||
if(closeBracketMatches) {
|
||||
totalCloseBracketMatches = closeBracketMatches.length;
|
||||
}
|
||||
|
||||
openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
|
||||
if(openSquareBracketMatches) {
|
||||
totalOpenSquareBracketMatches = openSquareBracketMatches.length;
|
||||
}
|
||||
|
||||
closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
|
||||
if(closeSquareBracketMatches) {
|
||||
totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
|
||||
}
|
||||
|
||||
openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
|
||||
if(openCurlyBracketMatches) {
|
||||
totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
|
||||
}
|
||||
|
||||
closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
|
||||
if(closeCurlyBracketMatches) {
|
||||
totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
|
||||
}
|
||||
|
||||
if(totalOpenBracketMatches != totalCloseBracketMatches) {
|
||||
if(!counterElt.title.includes(errorStringParen)) {
|
||||
counterElt.title += errorStringParen;
|
||||
}
|
||||
} else {
|
||||
counterElt.title = counterElt.title.replace(errorStringParen, '');
|
||||
}
|
||||
|
||||
if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
|
||||
if(!counterElt.title.includes(errorStringSquare)) {
|
||||
counterElt.title += errorStringSquare;
|
||||
}
|
||||
} else {
|
||||
counterElt.title = counterElt.title.replace(errorStringSquare, '');
|
||||
}
|
||||
|
||||
if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
|
||||
if(!counterElt.title.includes(errorStringCurly)) {
|
||||
counterElt.title += errorStringCurly;
|
||||
}
|
||||
} else {
|
||||
counterElt.title = counterElt.title.replace(errorStringCurly, '');
|
||||
}
|
||||
|
||||
if(counterElt.title != '') {
|
||||
counterElt.style = 'color: #FF5555;';
|
||||
} else {
|
||||
counterElt.style = '';
|
||||
}
|
||||
}
|
||||
|
||||
var shadowRootLoaded = setInterval(function() {
|
||||
var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
|
||||
if(shadowTextArea.length < 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
clearInterval(shadowRootLoaded);
|
||||
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
|
||||
document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
|
||||
}, 1000);
|
50
extensions-builtin/roll-artist/scripts/roll-artist.py
Normal file
50
extensions-builtin/roll-artist/scripts/roll-artist.py
Normal file
@ -0,0 +1,50 @@
|
||||
import random
|
||||
|
||||
from modules import script_callbacks, shared
|
||||
import gradio as gr
|
||||
|
||||
art_symbol = '\U0001f3a8' # 🎨
|
||||
global_prompt = None
|
||||
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
|
||||
|
||||
|
||||
def roll_artist(prompt):
|
||||
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
|
||||
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
|
||||
|
||||
return prompt + ", " + artist.name if prompt != '' else artist.name
|
||||
|
||||
|
||||
def add_roll_button(prompt):
|
||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||
|
||||
roll.click(
|
||||
fn=roll_artist,
|
||||
_js="update_txt2img_tokens",
|
||||
inputs=[
|
||||
prompt,
|
||||
],
|
||||
outputs=[
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def after_component(component, **kwargs):
|
||||
global global_prompt
|
||||
|
||||
elem_id = kwargs.get('elem_id', None)
|
||||
if elem_id not in related_ids:
|
||||
return
|
||||
|
||||
if elem_id == "txt2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "txt2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
elif elem_id == "img2img_prompt":
|
||||
global_prompt = component
|
||||
elif elem_id == "img2img_clear_prompt":
|
||||
add_roll_button(global_prompt)
|
||||
|
||||
|
||||
script_callbacks.on_after_component(after_component)
|
0
extensions/put extensions here.txt
Normal file
0
extensions/put extensions here.txt
Normal file
9
html/footer.html
Normal file
9
html/footer.html
Normal file
@ -0,0 +1,9 @@
|
||||
<div>
|
||||
<a href="/docs">API</a>
|
||||
•
|
||||
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
|
||||
•
|
||||
<a href="https://gradio.app">Gradio</a>
|
||||
•
|
||||
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
|
||||
</div>
|
392
html/licenses.html
Normal file
392
html/licenses.html
Normal file
@ -0,0 +1,392 @@
|
||||
<style>
|
||||
#licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}
|
||||
#licenses small {font-size: 0.95em; opacity: 0.85;}
|
||||
#licenses pre { margin: 1em 0 2em 0;}
|
||||
</style>
|
||||
|
||||
<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
|
||||
<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
|
||||
<pre>
|
||||
S-Lab License 1.0
|
||||
|
||||
Copyright 2022 S-Lab
|
||||
|
||||
Redistribution and use for non-commercial purpose in source and
|
||||
binary forms, with or without modification, are permitted provided
|
||||
that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
In the event that redistribution and/or use for commercial purpose in
|
||||
source or binary forms, with or without modification is required,
|
||||
please contact the contributor(s) of the work.
|
||||
</pre>
|
||||
|
||||
|
||||
<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
|
||||
<small>Code for architecture and reading models copied.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 victorca25
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
|
||||
<small>Some code is copied to support ESRGAN models.</small>
|
||||
<pre>
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2021, Xintao Wang
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
|
||||
<small>Some code for compatibility with OSX is taken from lstein's repository.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 InvokeAI Team
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE">LDSR</a></h2>
|
||||
<small>Code added by contirubtors, most likely copied from this repository.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE">CLIP Interrogator</a></h2>
|
||||
<small>Some small amounts of code borrowed and reworked.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 pharmapsychotic
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
|
||||
<small>Code added by contirubtors, most likely copied from this repository.</small>
|
||||
|
||||
<pre>
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [2021] [SwinIR Authors]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
</pre>
|
||||
|
@ -3,12 +3,12 @@ let currentWidth = null;
|
||||
let currentHeight = null;
|
||||
let arFrameTimeout = setTimeout(function(){},0);
|
||||
|
||||
function dimensionChange(e,dimname){
|
||||
function dimensionChange(e, is_width, is_height){
|
||||
|
||||
if(dimname == 'Width'){
|
||||
if(is_width){
|
||||
currentWidth = e.target.value*1.0
|
||||
}
|
||||
if(dimname == 'Height'){
|
||||
if(is_height){
|
||||
currentHeight = e.target.value*1.0
|
||||
}
|
||||
|
||||
@ -18,22 +18,13 @@ function dimensionChange(e,dimname){
|
||||
return;
|
||||
}
|
||||
|
||||
var img2imgMode = gradioApp().querySelector('#mode_img2img.tabs > div > button.rounded-t-lg.border-gray-200')
|
||||
if(img2imgMode){
|
||||
img2imgMode=img2imgMode.innerText
|
||||
}else{
|
||||
return;
|
||||
}
|
||||
|
||||
var redrawImage = gradioApp().querySelector('div[data-testid=image] img');
|
||||
var inpaintImage = gradioApp().querySelector('#img2maskimg div[data-testid=image] img')
|
||||
|
||||
var targetElement = null;
|
||||
|
||||
if(img2imgMode=='img2img' && redrawImage){
|
||||
targetElement = redrawImage;
|
||||
}else if(img2imgMode=='Inpaint' && inpaintImage){
|
||||
targetElement = inpaintImage;
|
||||
var tabIndex = get_tab_index('mode_img2img')
|
||||
if(tabIndex == 0){
|
||||
targetElement = gradioApp().querySelector('div[data-testid=image] img');
|
||||
} else if(tabIndex == 1){
|
||||
targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
|
||||
}
|
||||
|
||||
if(targetElement){
|
||||
@ -99,21 +90,19 @@ onUiUpdate(function(){
|
||||
if(inImg2img){
|
||||
let inputs = gradioApp().querySelectorAll('input');
|
||||
inputs.forEach(function(e){
|
||||
let parentLabel = e.parentElement.querySelector('label')
|
||||
if(parentLabel && parentLabel.innerText){
|
||||
if(!e.classList.contains('scrollwatch')){
|
||||
if(parentLabel.innerText == 'Width' || parentLabel.innerText == 'Height'){
|
||||
e.addEventListener('input', function(e){dimensionChange(e,parentLabel.innerText)} )
|
||||
var is_width = e.parentElement.id == "img2img_width"
|
||||
var is_height = e.parentElement.id == "img2img_height"
|
||||
|
||||
if((is_width || is_height) && !e.classList.contains('scrollwatch')){
|
||||
e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
|
||||
e.classList.add('scrollwatch')
|
||||
}
|
||||
if(parentLabel.innerText == 'Width'){
|
||||
if(is_width){
|
||||
currentWidth = e.value*1.0
|
||||
}
|
||||
if(parentLabel.innerText == 'Height'){
|
||||
if(is_height){
|
||||
currentHeight = e.value*1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
});
|
||||
|
@ -61,15 +61,15 @@ contextMenuInit = function(){
|
||||
|
||||
}
|
||||
|
||||
function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){
|
||||
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
|
||||
|
||||
currentItems = menuSpecs.get(targetEmementSelector)
|
||||
currentItems = menuSpecs.get(targetElementSelector)
|
||||
|
||||
if(!currentItems){
|
||||
currentItems = []
|
||||
menuSpecs.set(targetEmementSelector,currentItems);
|
||||
menuSpecs.set(targetElementSelector,currentItems);
|
||||
}
|
||||
let newItem = {'id':targetEmementSelector+'_'+uid(),
|
||||
let newItem = {'id':targetElementSelector+'_'+uid(),
|
||||
'name':entryName,
|
||||
'func':entryFunction,
|
||||
'isNew':true}
|
||||
|
10
javascript/dragdrop.js
vendored
10
javascript/dragdrop.js
vendored
@ -9,11 +9,19 @@ function dropReplaceImage( imgWrap, files ) {
|
||||
return;
|
||||
}
|
||||
|
||||
const tmpFile = files[0];
|
||||
|
||||
imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
|
||||
const callback = () => {
|
||||
const fileInput = imgWrap.querySelector('input[type="file"]');
|
||||
if ( fileInput ) {
|
||||
if ( files.length === 0 ) {
|
||||
files = new DataTransfer();
|
||||
files.items.add(tmpFile);
|
||||
fileInput.files = files.files;
|
||||
} else {
|
||||
fileInput.files = files;
|
||||
}
|
||||
fileInput.dispatchEvent(new Event('change'));
|
||||
}
|
||||
};
|
||||
@ -43,7 +51,7 @@ function dropReplaceImage( imgWrap, files ) {
|
||||
window.document.addEventListener('dragover', e => {
|
||||
const target = e.composedPath()[0];
|
||||
const imgWrap = target.closest('[data-testid="image"]');
|
||||
if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) {
|
||||
if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
|
||||
return;
|
||||
}
|
||||
e.stopPropagation();
|
||||
|
@ -1,7 +1,6 @@
|
||||
addEventListener('keydown', (event) => {
|
||||
let target = event.originalTarget || event.composedPath()[0];
|
||||
if (!target.hasAttribute("placeholder")) return;
|
||||
if (!target.placeholder.toLowerCase().includes("prompt")) return;
|
||||
if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
|
||||
if (! (event.metaKey || event.ctrlKey)) return;
|
||||
|
||||
|
||||
|
35
javascript/extensions.js
Normal file
35
javascript/extensions.js
Normal file
@ -0,0 +1,35 @@
|
||||
|
||||
function extensions_apply(_, _){
|
||||
disable = []
|
||||
update = []
|
||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||
if(x.name.startsWith("enable_") && ! x.checked)
|
||||
disable.push(x.name.substr(7))
|
||||
|
||||
if(x.name.startsWith("update_") && x.checked)
|
||||
update.push(x.name.substr(7))
|
||||
})
|
||||
|
||||
restart_reload()
|
||||
|
||||
return [JSON.stringify(disable), JSON.stringify(update)]
|
||||
}
|
||||
|
||||
function extensions_check(){
|
||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||
x.innerHTML = "Loading..."
|
||||
})
|
||||
|
||||
return []
|
||||
}
|
||||
|
||||
function install_extension_from_index(button, url){
|
||||
button.disabled = "disabled"
|
||||
button.value = "Installing..."
|
||||
|
||||
textarea = gradioApp().querySelector('#extension_to_install textarea')
|
||||
textarea.value = url
|
||||
textarea.dispatchEvent(new Event("input", { bubbles: true }))
|
||||
|
||||
gradioApp().querySelector('#install_extension_button').click()
|
||||
}
|
33
javascript/generationParams.js
Normal file
33
javascript/generationParams.js
Normal file
@ -0,0 +1,33 @@
|
||||
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
|
||||
|
||||
let txt2img_gallery, img2img_gallery, modal = undefined;
|
||||
onUiUpdate(function(){
|
||||
if (!txt2img_gallery) {
|
||||
txt2img_gallery = attachGalleryListeners("txt2img")
|
||||
}
|
||||
if (!img2img_gallery) {
|
||||
img2img_gallery = attachGalleryListeners("img2img")
|
||||
}
|
||||
if (!modal) {
|
||||
modal = gradioApp().getElementById('lightboxModal')
|
||||
modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
|
||||
}
|
||||
});
|
||||
|
||||
let modalObserver = new MutationObserver(function(mutations) {
|
||||
mutations.forEach(function(mutationRecord) {
|
||||
let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
|
||||
if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
|
||||
gradioApp().getElementById(selectedTab+"_generation_info_button").click()
|
||||
});
|
||||
});
|
||||
|
||||
function attachGalleryListeners(tab_name) {
|
||||
gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
|
||||
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
|
||||
gallery?.addEventListener('keydown', (e) => {
|
||||
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
|
||||
gradioApp().getElementById(tab_name+"_generation_info_button").click()
|
||||
});
|
||||
return gallery;
|
||||
}
|
@ -6,6 +6,7 @@ titles = {
|
||||
"GFPGAN": "Restore low quality faces using GFPGAN neural network",
|
||||
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
|
||||
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
|
||||
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
|
||||
|
||||
"Batch count": "How many batches of images to create",
|
||||
"Batch size": "How many image to create in a single batch",
|
||||
@ -17,6 +18,7 @@ titles = {
|
||||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||
"\u{1f4c2}": "Open images output directory",
|
||||
"\u{1f4be}": "Save style",
|
||||
"\U0001F5D1": "Clear prompt",
|
||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||
|
||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||
@ -62,8 +64,8 @@ titles = {
|
||||
|
||||
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
||||
|
||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
|
||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
|
||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
||||
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
||||
|
||||
"Loopback": "Process an image, use it as an input, repeat.",
|
||||
@ -75,6 +77,7 @@ titles = {
|
||||
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
|
||||
|
||||
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
|
||||
"Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
|
||||
|
||||
"vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
|
||||
|
||||
@ -91,6 +94,13 @@ titles = {
|
||||
|
||||
"Weighted sum": "Result = A * (1 - M) + B * M",
|
||||
"Add difference": "Result = A + (B - C) * M",
|
||||
|
||||
"Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
||||
|
||||
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
|
||||
|
||||
"Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
|
||||
"Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality."
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,206 +0,0 @@
|
||||
var images_history_click_image = function(){
|
||||
if (!this.classList.contains("transform")){
|
||||
var gallery = images_history_get_parent_by_class(this, "images_history_cantainor");
|
||||
var buttons = gallery.querySelectorAll(".gallery-item");
|
||||
var i = 0;
|
||||
var hidden_list = [];
|
||||
buttons.forEach(function(e){
|
||||
if (e.style.display == "none"){
|
||||
hidden_list.push(i);
|
||||
}
|
||||
i += 1;
|
||||
})
|
||||
if (hidden_list.length > 0){
|
||||
setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
|
||||
}
|
||||
}
|
||||
images_history_set_image_info(this);
|
||||
}
|
||||
|
||||
var images_history_click_tab = function(){
|
||||
var tabs_box = gradioApp().getElementById("images_history_tab");
|
||||
if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
|
||||
gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
|
||||
tabs_box.classList.add(this.getAttribute("tabname"))
|
||||
}
|
||||
}
|
||||
|
||||
function images_history_disabled_del(){
|
||||
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
|
||||
btn.setAttribute('disabled','disabled');
|
||||
});
|
||||
}
|
||||
|
||||
function images_history_get_parent_by_class(item, class_name){
|
||||
var parent = item.parentElement;
|
||||
while(!parent.classList.contains(class_name)){
|
||||
parent = parent.parentElement;
|
||||
}
|
||||
return parent;
|
||||
}
|
||||
|
||||
function images_history_get_parent_by_tagname(item, tagname){
|
||||
var parent = item.parentElement;
|
||||
tagname = tagname.toUpperCase()
|
||||
while(parent.tagName != tagname){
|
||||
console.log(parent.tagName, tagname)
|
||||
parent = parent.parentElement;
|
||||
}
|
||||
return parent;
|
||||
}
|
||||
|
||||
function images_history_hide_buttons(hidden_list, gallery){
|
||||
var buttons = gallery.querySelectorAll(".gallery-item");
|
||||
var num = 0;
|
||||
buttons.forEach(function(e){
|
||||
if (e.style.display == "none"){
|
||||
num += 1;
|
||||
}
|
||||
});
|
||||
if (num == hidden_list.length){
|
||||
setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
|
||||
}
|
||||
for( i in hidden_list){
|
||||
buttons[hidden_list[i]].style.display = "none";
|
||||
}
|
||||
}
|
||||
|
||||
function images_history_set_image_info(button){
|
||||
var buttons = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item");
|
||||
var index = -1;
|
||||
var i = 0;
|
||||
buttons.forEach(function(e){
|
||||
if(e == button){
|
||||
index = i;
|
||||
}
|
||||
if(e.style.display != "none"){
|
||||
i += 1;
|
||||
}
|
||||
});
|
||||
var gallery = images_history_get_parent_by_class(button, "images_history_cantainor");
|
||||
var set_btn = gallery.querySelector(".images_history_set_index");
|
||||
var curr_idx = set_btn.getAttribute("img_index", index);
|
||||
if (curr_idx != index) {
|
||||
set_btn.setAttribute("img_index", index);
|
||||
images_history_disabled_del();
|
||||
}
|
||||
set_btn.click();
|
||||
|
||||
}
|
||||
|
||||
function images_history_get_current_img(tabname, image_path, files){
|
||||
return [
|
||||
gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
|
||||
image_path,
|
||||
files
|
||||
];
|
||||
}
|
||||
|
||||
function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
|
||||
image_index = parseInt(image_index);
|
||||
var tab = gradioApp().getElementById(tabname + '_images_history');
|
||||
var set_btn = tab.querySelector(".images_history_set_index");
|
||||
var buttons = [];
|
||||
tab.querySelectorAll(".gallery-item").forEach(function(e){
|
||||
if (e.style.display != 'none'){
|
||||
buttons.push(e);
|
||||
}
|
||||
});
|
||||
var img_num = buttons.length / 2;
|
||||
if (img_num <= del_num){
|
||||
setTimeout(function(tabname){
|
||||
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
||||
}, 30, tabname);
|
||||
} else {
|
||||
var next_img
|
||||
for (var i = 0; i < del_num; i++){
|
||||
if (image_index + i < image_index + img_num){
|
||||
buttons[image_index + i].style.display = 'none';
|
||||
buttons[image_index + img_num + 1].style.display = 'none';
|
||||
next_img = image_index + i + 1
|
||||
}
|
||||
}
|
||||
var bnt;
|
||||
if (next_img >= img_num){
|
||||
btn = buttons[image_index - del_num];
|
||||
} else {
|
||||
btn = buttons[next_img];
|
||||
}
|
||||
setTimeout(function(btn){btn.click()}, 30, btn);
|
||||
}
|
||||
images_history_disabled_del();
|
||||
return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
|
||||
}
|
||||
|
||||
function images_history_turnpage(img_path, page_index, image_index, tabname){
|
||||
var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
|
||||
buttons.forEach(function(elem) {
|
||||
elem.style.display = 'block';
|
||||
})
|
||||
return [img_path, page_index, image_index, tabname];
|
||||
}
|
||||
|
||||
function images_history_enable_del_buttons(){
|
||||
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
|
||||
btn.removeAttribute('disabled');
|
||||
})
|
||||
}
|
||||
|
||||
function images_history_init(){
|
||||
var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page')
|
||||
if (load_txt2img_button){
|
||||
for (var i in images_history_tab_list ){
|
||||
tab = images_history_tab_list[i];
|
||||
gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
|
||||
gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
|
||||
gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
|
||||
gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
|
||||
|
||||
}
|
||||
var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
|
||||
tabs_box.setAttribute("id", "images_history_tab");
|
||||
var tab_btns = tabs_box.querySelectorAll("button");
|
||||
for (var i in images_history_tab_list){
|
||||
var tabname = images_history_tab_list[i]
|
||||
tab_btns[i].setAttribute("tabname", tabname);
|
||||
|
||||
// this refreshes history upon tab switch
|
||||
// until the history is known to work well, which is not the case now, we do not do this at startup
|
||||
//tab_btns[i].addEventListener('click', images_history_click_tab);
|
||||
}
|
||||
tabs_box.classList.add(images_history_tab_list[0]);
|
||||
|
||||
// same as above, at page load
|
||||
//load_txt2img_button.click();
|
||||
} else {
|
||||
setTimeout(images_history_init, 500);
|
||||
}
|
||||
}
|
||||
|
||||
var images_history_tab_list = ["txt2img", "img2img", "extras"];
|
||||
setTimeout(images_history_init, 500);
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
var mutationObserver = new MutationObserver(function(m){
|
||||
for (var i in images_history_tab_list ){
|
||||
let tabname = images_history_tab_list[i]
|
||||
var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
|
||||
buttons.forEach(function(bnt){
|
||||
bnt.addEventListener('click', images_history_click_image, true);
|
||||
});
|
||||
|
||||
// same as load_txt2img_button.click() above
|
||||
/*
|
||||
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
|
||||
if (cls_btn){
|
||||
cls_btn.addEventListener('click', function(){
|
||||
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
||||
}, false);
|
||||
}*/
|
||||
|
||||
}
|
||||
});
|
||||
mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
|
||||
|
||||
});
|
||||
|
||||
|
@ -13,6 +13,15 @@ function showModal(event) {
|
||||
}
|
||||
lb.style.display = "block";
|
||||
lb.focus()
|
||||
|
||||
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
|
||||
const tabImg2Img = gradioApp().getElementById("tab_img2img")
|
||||
// show the save button in modal only on txt2img or img2img tabs
|
||||
if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
|
||||
gradioApp().getElementById("modal_save").style.display = "inline"
|
||||
} else {
|
||||
gradioApp().getElementById("modal_save").style.display = "none"
|
||||
}
|
||||
event.stopPropagation()
|
||||
}
|
||||
|
||||
@ -81,6 +90,25 @@ function modalImageSwitch(offset) {
|
||||
}
|
||||
}
|
||||
|
||||
function saveImage(){
|
||||
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
|
||||
const tabImg2Img = gradioApp().getElementById("tab_img2img")
|
||||
const saveTxt2Img = "save_txt2img"
|
||||
const saveImg2Img = "save_img2img"
|
||||
if (tabTxt2Img.style.display != "none") {
|
||||
gradioApp().getElementById(saveTxt2Img).click()
|
||||
} else if (tabImg2Img.style.display != "none") {
|
||||
gradioApp().getElementById(saveImg2Img).click()
|
||||
} else {
|
||||
console.error("missing implementation for saving modal of this type")
|
||||
}
|
||||
}
|
||||
|
||||
function modalSaveImage(event) {
|
||||
saveImage()
|
||||
event.stopPropagation()
|
||||
}
|
||||
|
||||
function modalNextImage(event) {
|
||||
modalImageSwitch(1)
|
||||
event.stopPropagation()
|
||||
@ -93,6 +121,9 @@ function modalPrevImage(event) {
|
||||
|
||||
function modalKeyHandler(event) {
|
||||
switch (event.key) {
|
||||
case "s":
|
||||
saveImage()
|
||||
break;
|
||||
case "ArrowLeft":
|
||||
modalPrevImage(event)
|
||||
break;
|
||||
@ -198,6 +229,14 @@ document.addEventListener("DOMContentLoaded", function() {
|
||||
modalTileImage.title = "Preview tiling";
|
||||
modalControls.appendChild(modalTileImage)
|
||||
|
||||
const modalSave = document.createElement("span")
|
||||
modalSave.className = "modalSave cursor"
|
||||
modalSave.id = "modal_save"
|
||||
modalSave.innerHTML = "🖫"
|
||||
modalSave.addEventListener("click", modalSaveImage, true)
|
||||
modalSave.title = "Save Image(s)"
|
||||
modalControls.appendChild(modalSave)
|
||||
|
||||
const modalClose = document.createElement('span')
|
||||
modalClose.className = 'modalClose cursor';
|
||||
modalClose.innerHTML = '×'
|
||||
|
@ -108,6 +108,9 @@ function processNode(node){
|
||||
|
||||
function dumpTranslations(){
|
||||
dumped = {}
|
||||
if (localization.rtl) {
|
||||
dumped.rtl = true
|
||||
}
|
||||
|
||||
Object.keys(original_lines).forEach(function(text){
|
||||
if(dumped[text] !== undefined) return
|
||||
@ -129,6 +132,24 @@ onUiUpdate(function(m){
|
||||
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
processNode(gradioApp())
|
||||
|
||||
if (localization.rtl) { // if the language is from right to left,
|
||||
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
||||
mutations.forEach(mutation => {
|
||||
mutation.addedNodes.forEach(node => {
|
||||
if (node.tagName === 'STYLE') {
|
||||
observer.disconnect();
|
||||
|
||||
for (const x of node.sheet.rules) { // find all rtl media rules
|
||||
if (Array.from(x.media || []).includes('rtl')) {
|
||||
x.media.appendMedium('all'); // enable them
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
})).observe(gradioApp(), { childList: true });
|
||||
}
|
||||
})
|
||||
|
||||
function download_localization() {
|
||||
|
@ -15,7 +15,7 @@ onUiUpdate(function(){
|
||||
}
|
||||
}
|
||||
|
||||
const galleryPreviews = gradioApp().querySelectorAll('img.h-full.w-full.overflow-hidden');
|
||||
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
|
||||
|
||||
if (galleryPreviews == null) return;
|
||||
|
||||
|
@ -3,14 +3,27 @@ global_progressbars = {}
|
||||
galleries = {}
|
||||
galleryObservers = {}
|
||||
|
||||
// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
|
||||
timeoutIds = {}
|
||||
|
||||
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
||||
var progressbar = gradioApp().getElementById(id_progressbar)
|
||||
// gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
|
||||
// every time you use gr.HTML(elem_id='xxx'), so we handle this here
|
||||
var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
|
||||
var progressbarParent
|
||||
if(progressbar){
|
||||
progressbarParent = gradioApp().querySelector("#"+id_progressbar)
|
||||
} else{
|
||||
progressbar = gradioApp().getElementById(id_progressbar)
|
||||
progressbarParent = null
|
||||
}
|
||||
|
||||
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
|
||||
var interrupt = gradioApp().getElementById(id_interrupt)
|
||||
|
||||
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
|
||||
if(progressbar.innerText){
|
||||
let newtitle = 'Stable Diffusion - ' + progressbar.innerText.slice(2)
|
||||
let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
|
||||
if(document.title != newtitle){
|
||||
document.title = newtitle;
|
||||
}
|
||||
@ -26,18 +39,26 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
|
||||
global_progressbars[id_progressbar] = progressbar
|
||||
|
||||
var mutationObserver = new MutationObserver(function(m){
|
||||
if(timeoutIds[id_part]) return;
|
||||
|
||||
preview = gradioApp().getElementById(id_preview)
|
||||
gallery = gradioApp().getElementById(id_gallery)
|
||||
|
||||
if(preview != null && gallery != null){
|
||||
preview.style.width = gallery.clientWidth + "px"
|
||||
preview.style.height = gallery.clientHeight + "px"
|
||||
if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
|
||||
|
||||
//only watch gallery if there is a generation process going on
|
||||
check_gallery(id_gallery);
|
||||
|
||||
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
||||
if(!progressDiv){
|
||||
if(progressDiv){
|
||||
timeoutIds[id_part] = window.setTimeout(function() {
|
||||
timeoutIds[id_part] = null
|
||||
requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
|
||||
}, 500)
|
||||
} else{
|
||||
if (skip) {
|
||||
skip.style.display = "none"
|
||||
}
|
||||
@ -49,11 +70,8 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
|
||||
galleries[id_gallery] = null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
|
||||
});
|
||||
mutationObserver.observe( progressbar, { childList:true, subtree:true })
|
||||
}
|
||||
@ -74,14 +92,26 @@ function check_gallery(id_gallery){
|
||||
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
|
||||
// automatically re-open previously selected index (if exists)
|
||||
activeElement = gradioApp().activeElement;
|
||||
let scrollX = window.scrollX;
|
||||
let scrollY = window.scrollY;
|
||||
|
||||
galleryButtons[prevSelectedIndex].click();
|
||||
showGalleryImage();
|
||||
|
||||
// When the gallery button is clicked, it gains focus and scrolls itself into view
|
||||
// We need to scroll back to the previous position
|
||||
setTimeout(function (){
|
||||
window.scrollTo(scrollX, scrollY);
|
||||
}, 50);
|
||||
|
||||
if(activeElement){
|
||||
// i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
|
||||
// if somenoe has a better solution please by all means
|
||||
setTimeout(function() { activeElement.focus() }, 1);
|
||||
// if someone has a better solution please by all means
|
||||
setTimeout(function (){
|
||||
activeElement.focus({
|
||||
preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it
|
||||
})
|
||||
}, 1);
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -1,4 +1,4 @@
|
||||
// various functions for interation with ui.py not large enough to warrant putting them in separate files
|
||||
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
|
||||
|
||||
function set_theme(theme){
|
||||
gradioURL = window.location.href
|
||||
@ -8,8 +8,8 @@ function set_theme(theme){
|
||||
}
|
||||
|
||||
function selected_gallery_index(){
|
||||
var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item')
|
||||
var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2')
|
||||
var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item')
|
||||
var button = gradioApp().querySelector('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item.\\!ring-2')
|
||||
|
||||
var result = -1
|
||||
buttons.forEach(function(v, i){ if(v==button) { result = i } })
|
||||
@ -19,7 +19,7 @@ function selected_gallery_index(){
|
||||
|
||||
function extract_image_from_gallery(gallery){
|
||||
if(gallery.length == 1){
|
||||
return gallery[0]
|
||||
return [gallery[0]]
|
||||
}
|
||||
|
||||
index = selected_gallery_index()
|
||||
@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){
|
||||
return [null]
|
||||
}
|
||||
|
||||
return gallery[index];
|
||||
return [gallery[index]];
|
||||
}
|
||||
|
||||
function args_to_array(args){
|
||||
@ -45,14 +45,14 @@ function switch_to_txt2img(){
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
function switch_to_img2img_img2img(){
|
||||
function switch_to_img2img(){
|
||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
|
||||
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
function switch_to_img2img_inpaint(){
|
||||
function switch_to_inpaint(){
|
||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
|
||||
|
||||
@ -65,26 +65,6 @@ function switch_to_extras(){
|
||||
return args_to_array(arguments);
|
||||
}
|
||||
|
||||
function extract_image_from_gallery_txt2img(gallery){
|
||||
switch_to_txt2img()
|
||||
return extract_image_from_gallery(gallery);
|
||||
}
|
||||
|
||||
function extract_image_from_gallery_img2img(gallery){
|
||||
switch_to_img2img_img2img()
|
||||
return extract_image_from_gallery(gallery);
|
||||
}
|
||||
|
||||
function extract_image_from_gallery_inpaint(gallery){
|
||||
switch_to_img2img_inpaint()
|
||||
return extract_image_from_gallery(gallery);
|
||||
}
|
||||
|
||||
function extract_image_from_gallery_extras(gallery){
|
||||
switch_to_extras()
|
||||
return extract_image_from_gallery(gallery);
|
||||
}
|
||||
|
||||
function get_tab_index(tabId){
|
||||
var res = 0
|
||||
|
||||
@ -120,7 +100,7 @@ function create_submit_args(args){
|
||||
|
||||
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
|
||||
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
|
||||
// I don't know why gradio is seding outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
|
||||
// I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
|
||||
// If gradio at some point stops sending outputs, this may break something
|
||||
if(Array.isArray(res[res.length - 3])){
|
||||
res[res.length - 3] = null
|
||||
@ -151,6 +131,15 @@ function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
||||
return [name_, prompt_text, negative_prompt_text]
|
||||
}
|
||||
|
||||
function confirm_clear_prompt(prompt, negative_prompt) {
|
||||
if(confirm("Delete prompt?")) {
|
||||
prompt = ""
|
||||
negative_prompt = ""
|
||||
}
|
||||
|
||||
return [prompt, negative_prompt]
|
||||
}
|
||||
|
||||
|
||||
|
||||
opts = {}
|
||||
@ -199,6 +188,17 @@ onUiUpdate(function(){
|
||||
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
||||
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
||||
}
|
||||
|
||||
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
|
||||
settings_tabs = gradioApp().querySelector('#settings div')
|
||||
if(show_all_pages && settings_tabs){
|
||||
settings_tabs.appendChild(show_all_pages)
|
||||
show_all_pages.onclick = function(){
|
||||
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
|
||||
elem.style.display = "block";
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
let txt2img_textarea, img2img_textarea = undefined;
|
||||
@ -228,4 +228,6 @@ function update_token_counter(button_id) {
|
||||
function restart_reload(){
|
||||
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
||||
setTimeout(function(){location.reload()},2000)
|
||||
|
||||
return []
|
||||
}
|
||||
|
135
launch.py
135
launch.py
@ -5,8 +5,11 @@ import sys
|
||||
import importlib.util
|
||||
import shlex
|
||||
import platform
|
||||
import argparse
|
||||
import json
|
||||
|
||||
dir_repos = "repositories"
|
||||
dir_extensions = "extensions"
|
||||
python = sys.executable
|
||||
git = os.environ.get('GIT', "git")
|
||||
index_url = os.environ.get('INDEX_URL', "")
|
||||
@ -16,11 +19,24 @@ def extract_arg(args, name):
|
||||
return [x for x in args if x != name], name in args
|
||||
|
||||
|
||||
def run(command, desc=None, errdesc=None):
|
||||
def extract_opt(args, name):
|
||||
opt = None
|
||||
is_present = False
|
||||
if name in args:
|
||||
is_present = True
|
||||
idx = args.index(name)
|
||||
del args[idx]
|
||||
if idx < len(args) and args[idx][0] != "-":
|
||||
opt = args[idx]
|
||||
del args[idx]
|
||||
return args, is_present, opt
|
||||
|
||||
|
||||
def run(command, desc=None, errdesc=None, custom_env=None):
|
||||
if desc is not None:
|
||||
print(desc)
|
||||
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
|
||||
|
||||
if result.returncode != 0:
|
||||
|
||||
@ -101,39 +117,81 @@ def version_check(commit):
|
||||
else:
|
||||
print("Not a git clone, can't perform version check.")
|
||||
except Exception as e:
|
||||
print("versipm check failed",e)
|
||||
print("version check failed", e)
|
||||
|
||||
|
||||
def prepare_enviroment():
|
||||
def run_extension_installer(extension_dir):
|
||||
path_installer = os.path.join(extension_dir, "install.py")
|
||||
if not os.path.isfile(path_installer):
|
||||
return
|
||||
|
||||
try:
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = os.path.abspath(".")
|
||||
|
||||
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
|
||||
except Exception as e:
|
||||
print(e, file=sys.stderr)
|
||||
|
||||
|
||||
def list_extensions(settings_file):
|
||||
settings = {}
|
||||
|
||||
try:
|
||||
if os.path.isfile(settings_file):
|
||||
with open(settings_file, "r", encoding="utf8") as file:
|
||||
settings = json.load(file)
|
||||
except Exception as e:
|
||||
print(e, file=sys.stderr)
|
||||
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||
|
||||
return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
|
||||
|
||||
|
||||
def run_extensions_installers(settings_file):
|
||||
if not os.path.isdir(dir_extensions):
|
||||
return
|
||||
|
||||
for dirname_extension in list_extensions(settings_file):
|
||||
run_extension_installer(os.path.join(dir_extensions, dirname_extension))
|
||||
|
||||
|
||||
def prepare_environment():
|
||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
|
||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||
deepdanbooru_package = os.environ.get('DEEPDANBOORU_PACKAGE', "git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26")
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
||||
|
||||
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
|
||||
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git")
|
||||
taming_transformers_repo = os.environ.get('TAMING_REANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
codeformer_repo = os.environ.get('CODEFORMET_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
||||
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||
|
||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
|
||||
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
|
||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
|
||||
sys.argv += shlex.split(commandline_args)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
|
||||
args, _ = parser.parse_known_args(sys.argv)
|
||||
|
||||
sys.argv, _ = extract_arg(sys.argv, '-f')
|
||||
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
|
||||
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
|
||||
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
|
||||
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
|
||||
xformers = '--xformers' in sys.argv
|
||||
deepdanbooru = '--deepdanbooru' in sys.argv
|
||||
ngrok = '--ngrok' in sys.argv
|
||||
|
||||
try:
|
||||
@ -156,21 +214,27 @@ def prepare_enviroment():
|
||||
if not is_installed("clip"):
|
||||
run_pip(f"install {clip_package}", "clip")
|
||||
|
||||
if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"):
|
||||
if not is_installed("open_clip"):
|
||||
run_pip(f"install {openclip_package}", "open_clip")
|
||||
|
||||
if (not is_installed("xformers") or reinstall_xformers) and xformers:
|
||||
if platform.system() == "Windows":
|
||||
if platform.python_version().startswith("3.10"):
|
||||
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
|
||||
else:
|
||||
print("Installation of xformers is not supported in this version of Python.")
|
||||
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
||||
if not is_installed("xformers"):
|
||||
exit(0)
|
||||
elif platform.system() == "Linux":
|
||||
run_pip("install xformers", "xformers")
|
||||
|
||||
if not is_installed("deepdanbooru") and deepdanbooru:
|
||||
run_pip(f"install {deepdanbooru_package}#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")
|
||||
|
||||
if not is_installed("pyngrok") and ngrok:
|
||||
run_pip("install pyngrok", "ngrok")
|
||||
|
||||
os.makedirs(dir_repos, exist_ok=True)
|
||||
|
||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||
@ -181,6 +245,8 @@ def prepare_enviroment():
|
||||
|
||||
run_pip(f"install -r {requirements_file}", "requirements for Web UI")
|
||||
|
||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||
|
||||
if update_check:
|
||||
version_check(commit)
|
||||
|
||||
@ -188,13 +254,42 @@ def prepare_enviroment():
|
||||
print("Exiting because of --exit argument")
|
||||
exit(0)
|
||||
|
||||
if run_tests:
|
||||
exitcode = tests(test_dir)
|
||||
exit(exitcode)
|
||||
|
||||
def start_webui():
|
||||
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
|
||||
|
||||
def tests(test_dir):
|
||||
if "--api" not in sys.argv:
|
||||
sys.argv.append("--api")
|
||||
if "--ckpt" not in sys.argv:
|
||||
sys.argv.append("--ckpt")
|
||||
sys.argv.append("./test/test_files/empty.pt")
|
||||
if "--skip-torch-cuda-test" not in sys.argv:
|
||||
sys.argv.append("--skip-torch-cuda-test")
|
||||
|
||||
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
|
||||
|
||||
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
|
||||
proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
|
||||
|
||||
import test.server_poll
|
||||
exitcode = test.server_poll.run_tests(proc, test_dir)
|
||||
|
||||
print(f"Stopping Web UI process with id {proc.pid}")
|
||||
proc.kill()
|
||||
return exitcode
|
||||
|
||||
|
||||
def start():
|
||||
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
|
||||
import webui
|
||||
if '--nowebui' in sys.argv:
|
||||
webui.api_only()
|
||||
else:
|
||||
webui.webui()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
prepare_enviroment()
|
||||
start_webui()
|
||||
prepare_environment()
|
||||
start()
|
||||
|
BIN
models/VAE-approx/model.pt
Normal file
BIN
models/VAE-approx/model.pt
Normal file
Binary file not shown.
0
models/VAE/Put VAE here.txt
Normal file
0
models/VAE/Put VAE here.txt
Normal file
@ -1,67 +1,461 @@
|
||||
from modules.api.processing import StableDiffusionProcessingAPI
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
|
||||
from modules.sd_samplers import all_samplers
|
||||
from modules.extras import run_pnginfo
|
||||
import modules.shared as shared
|
||||
import uvicorn
|
||||
from fastapi import Body, APIRouter, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, Json
|
||||
import json
|
||||
import io
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
import datetime
|
||||
import uvicorn
|
||||
from threading import Lock
|
||||
from io import BytesIO
|
||||
from gradio.processing_utils import decode_base64_to_file
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from secrets import compare_digest
|
||||
|
||||
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack
|
||||
from modules.api.models import *
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.extras import run_extras, run_pnginfo
|
||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list, find_checkpoint_config
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
from typing import List
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: Json
|
||||
info: Json
|
||||
def upscaler_to_index(name: str):
|
||||
try:
|
||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
|
||||
|
||||
def validate_sampler_name(name):
|
||||
config = sd_samplers.all_samplers_map.get(name, None)
|
||||
if config is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
|
||||
return name
|
||||
|
||||
def setUpscalers(req: dict):
|
||||
reqDict = vars(req)
|
||||
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
|
||||
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
|
||||
reqDict.pop('upscaler_1')
|
||||
reqDict.pop('upscaler_2')
|
||||
return reqDict
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
return Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
|
||||
def encode_pil_to_base64(image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
|
||||
# Copy any text-only metadata
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
|
||||
image.save(
|
||||
output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
|
||||
)
|
||||
bytes_data = output_bytes.getvalue()
|
||||
return base64.b64encode(bytes_data)
|
||||
|
||||
def api_middleware(app: FastAPI):
|
||||
@app.middleware("http")
|
||||
async def log_and_time(req: Request, call_next):
|
||||
ts = time.time()
|
||||
res: Response = await call_next(req)
|
||||
duration = str(round(time.time() - ts, 4))
|
||||
res.headers["X-Process-Time"] = duration
|
||||
endpoint = req.scope.get('path', 'err')
|
||||
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
|
||||
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
|
||||
t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
|
||||
code = res.status_code,
|
||||
ver = req.scope.get('http_version', '0.0'),
|
||||
cli = req.scope.get('client', ('0:0.0.0', 0))[0],
|
||||
prot = req.scope.get('scheme', 'err'),
|
||||
method = req.scope.get('method', 'err'),
|
||||
endpoint = endpoint,
|
||||
duration = duration,
|
||||
))
|
||||
return res
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app, queue_lock):
|
||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||
if shared.cmd_opts.api_auth:
|
||||
self.credentials = dict()
|
||||
for auth in shared.cmd_opts.api_auth.split(","):
|
||||
user, password = auth.split(":")
|
||||
self.credentials[user] = password
|
||||
|
||||
self.router = APIRouter()
|
||||
self.app = app
|
||||
self.queue_lock = queue_lock
|
||||
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
|
||||
api_middleware(self.app)
|
||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
||||
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
|
||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
|
||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
|
||||
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
if shared.cmd_opts.api_auth:
|
||||
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
|
||||
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||
|
||||
if sampler_index is None:
|
||||
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||
def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
|
||||
if credentials.username in self.credentials:
|
||||
if compare_digest(credentials.password, self.credentials[credentials.username]):
|
||||
return True
|
||||
|
||||
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
"sd_model": shared.sd_model,
|
||||
"sampler_index": sampler_index[0],
|
||||
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True
|
||||
}
|
||||
)
|
||||
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
||||
if populate.sampler_name:
|
||||
populate.sampler_index = None # prevent a warning later on
|
||||
|
||||
with self.queue_lock:
|
||||
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
|
||||
|
||||
shared.state.begin()
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
|
||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
||||
init_images = img2imgreq.init_images
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
|
||||
mask = img2imgreq.mask
|
||||
if mask:
|
||||
mask = decode_base64_to_image(mask)
|
||||
|
||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
|
||||
"do_not_save_samples": True,
|
||||
"do_not_save_grid": True,
|
||||
"mask": mask
|
||||
}
|
||||
)
|
||||
if populate.sampler_name:
|
||||
populate.sampler_index = None # prevent a warning later on
|
||||
|
||||
args = vars(populate)
|
||||
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
|
||||
|
||||
with self.queue_lock:
|
||||
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
|
||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||
|
||||
shared.state.begin()
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
if not img2imgreq.include_init_images:
|
||||
img2imgreq.init_images = None
|
||||
img2imgreq.mask = None
|
||||
|
||||
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
||||
|
||||
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||
|
||||
with self.queue_lock:
|
||||
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||
|
||||
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
||||
reqDict = setUpscalers(req)
|
||||
|
||||
def prepareFiles(file):
|
||||
file = decode_base64_to_file(file.data, file_path=file.name)
|
||||
file.orig_name = file.name
|
||||
return file
|
||||
|
||||
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
|
||||
reqDict.pop('imageList')
|
||||
|
||||
with self.queue_lock:
|
||||
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||
|
||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
def pnginfoapi(self, req: PNGInfoRequest):
|
||||
if(not req.image.strip()):
|
||||
return PNGInfoResponse(info="")
|
||||
|
||||
result = run_pnginfo(decode_base64_to_image(req.image.strip()))
|
||||
|
||||
return PNGInfoResponse(info=result[1])
|
||||
|
||||
def progressapi(self, req: ProgressRequest = Depends()):
|
||||
# copy from check_progress_call of ui.py
|
||||
|
||||
if shared.state.job_count == 0:
|
||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
|
||||
|
||||
# avoid dividing zero
|
||||
progress = 0.01
|
||||
|
||||
if shared.state.job_count > 0:
|
||||
progress += shared.state.job_no / shared.state.job_count
|
||||
if shared.state.sampling_steps > 0:
|
||||
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
|
||||
|
||||
time_since_start = time.time() - shared.state.time_start
|
||||
eta = (time_since_start/progress)
|
||||
eta_relative = eta-time_since_start
|
||||
|
||||
progress = min(progress, 1)
|
||||
|
||||
shared.state.set_current_image()
|
||||
|
||||
current_image = None
|
||||
if shared.state.current_image and not req.skip_current_image:
|
||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
|
||||
|
||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
||||
image_b64 = interrogatereq.image
|
||||
if image_b64 is None:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
img = decode_base64_to_image(image_b64)
|
||||
img = img.convert('RGB')
|
||||
|
||||
# Override object param
|
||||
with self.queue_lock:
|
||||
processed = process_images(p)
|
||||
if interrogatereq.model == "clip":
|
||||
processed = shared.interrogator.interrogate(img)
|
||||
elif interrogatereq.model == "deepdanbooru":
|
||||
processed = deepbooru.model.tag(img)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
|
||||
b64images = []
|
||||
for i in processed.images:
|
||||
buffer = io.BytesIO()
|
||||
i.save(buffer, format="png")
|
||||
b64images.append(base64.b64encode(buffer.getvalue()))
|
||||
return InterrogateResponse(caption=processed)
|
||||
|
||||
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
|
||||
def interruptapi(self):
|
||||
shared.state.interrupt()
|
||||
|
||||
return {}
|
||||
|
||||
def skip(self):
|
||||
shared.state.skip()
|
||||
|
||||
def img2imgapi(self):
|
||||
raise NotImplementedError
|
||||
def get_config(self):
|
||||
options = {}
|
||||
for key in shared.opts.data.keys():
|
||||
metadata = shared.opts.data_labels.get(key)
|
||||
if(metadata is not None):
|
||||
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
|
||||
else:
|
||||
options.update({key: shared.opts.data.get(key, None)})
|
||||
|
||||
def extrasapi(self):
|
||||
raise NotImplementedError
|
||||
return options
|
||||
|
||||
def pnginfoapi(self):
|
||||
raise NotImplementedError
|
||||
def set_config(self, req: Dict[str, Any]):
|
||||
for k, v in req.items():
|
||||
shared.opts.set(k, v)
|
||||
|
||||
shared.opts.save(shared.config_filename)
|
||||
return
|
||||
|
||||
def get_cmd_flags(self):
|
||||
return vars(shared.cmd_opts)
|
||||
|
||||
def get_samplers(self):
|
||||
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
|
||||
|
||||
def get_upscalers(self):
|
||||
upscalers = []
|
||||
|
||||
for upscaler in shared.sd_upscalers:
|
||||
u = upscaler.scaler
|
||||
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
|
||||
|
||||
return upscalers
|
||||
|
||||
def get_sd_models(self):
|
||||
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
||||
|
||||
def get_hypernetworks(self):
|
||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||
|
||||
def get_face_restorers(self):
|
||||
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
|
||||
|
||||
def get_realesrgan_models(self):
|
||||
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
|
||||
|
||||
def get_prompt_styles(self):
|
||||
styleList = []
|
||||
for k in shared.prompt_styles.styles:
|
||||
style = shared.prompt_styles.styles[k]
|
||||
styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
|
||||
|
||||
return styleList
|
||||
|
||||
def get_artists_categories(self):
|
||||
return shared.artist_db.cats
|
||||
|
||||
def get_artists(self):
|
||||
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
|
||||
|
||||
def get_embeddings(self):
|
||||
db = sd_hijack.model_hijack.embedding_db
|
||||
|
||||
def convert_embedding(embedding):
|
||||
return {
|
||||
"step": embedding.step,
|
||||
"sd_checkpoint": embedding.sd_checkpoint,
|
||||
"sd_checkpoint_name": embedding.sd_checkpoint_name,
|
||||
"shape": embedding.shape,
|
||||
"vectors": embedding.vectors,
|
||||
}
|
||||
|
||||
def convert_embeddings(embeddings):
|
||||
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
|
||||
|
||||
return {
|
||||
"loaded": convert_embeddings(db.word_embeddings),
|
||||
"skipped": convert_embeddings(db.skipped_embeddings),
|
||||
}
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
def create_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
filename = create_embedding(**args) # create empty embedding
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||
shared.state.end()
|
||||
return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "create embedding error: {error}".format(error = e))
|
||||
|
||||
def create_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
filename = create_hypernetwork(**args) # create empty embedding
|
||||
shared.state.end()
|
||||
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
|
||||
|
||||
def preprocess(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = 'preprocess complete')
|
||||
except KeyError as e:
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
|
||||
except AssertionError as e:
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
|
||||
except FileNotFoundError as e:
|
||||
shared.state.end()
|
||||
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
|
||||
|
||||
def train_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
error = None
|
||||
filename = ''
|
||||
if not apply_optimizations:
|
||||
sd_hijack.undo_optimizations()
|
||||
try:
|
||||
embedding, filename = train_embedding(**args) # can take a long time to complete
|
||||
except Exception as e:
|
||||
error = e
|
||||
finally:
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
||||
except AssertionError as msg:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
|
||||
|
||||
def train_hypernetwork(self, args: dict):
|
||||
try:
|
||||
shared.state.begin()
|
||||
initial_hypernetwork = shared.loaded_hypernetwork
|
||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
||||
error = None
|
||||
filename = ''
|
||||
if not apply_optimizations:
|
||||
sd_hijack.undo_optimizations()
|
||||
try:
|
||||
hypernetwork, filename = train_hypernetwork(*args)
|
||||
except Exception as e:
|
||||
error = e
|
||||
finally:
|
||||
shared.loaded_hypernetwork = initial_hypernetwork
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
if not apply_optimizations:
|
||||
sd_hijack.apply_optimizations()
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
||||
except AssertionError as msg:
|
||||
shared.state.end()
|
||||
return TrainResponse(info = "train embedding error: {error}".format(error = error))
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
|
261
modules/api/models.py
Normal file
261
modules/api/models.py
Normal file
@ -0,0 +1,261 @@
|
||||
import inspect
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import Literal
|
||||
from inflection import underscore
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
from modules.shared import sd_upscalers, opts, parser
|
||||
from typing import Dict, List
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
"kwargs",
|
||||
"sd_model",
|
||||
"outpath_samples",
|
||||
"outpath_grids",
|
||||
"sampler_index",
|
||||
"do_not_save_samples",
|
||||
"do_not_save_grid",
|
||||
"extra_generation_params",
|
||||
"overlay_images",
|
||||
"do_not_reload_embeddings",
|
||||
"seed_enable_extras",
|
||||
"prompt_for_display",
|
||||
"sampler_noise_scheduler_override",
|
||||
"ddim_discretize"
|
||||
]
|
||||
|
||||
class ModelDef(BaseModel):
|
||||
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
||||
|
||||
field: str
|
||||
field_alias: str
|
||||
field_type: Any
|
||||
field_value: Any
|
||||
field_exclude: bool = False
|
||||
|
||||
|
||||
class PydanticModelGenerator:
|
||||
"""
|
||||
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
||||
source_data is a snapshot of the default values produced by the class
|
||||
params are the names of the actual keys required by __init__
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
class_instance = None,
|
||||
additional_fields = None,
|
||||
):
|
||||
def field_type_generator(k, v):
|
||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||
# print(k, v.annotation, v.default)
|
||||
field_type = v.annotation
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
def merge_class_params(class_):
|
||||
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||
parameters = {}
|
||||
for classes in all_classes:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
|
||||
self._model_def = [
|
||||
ModelDef(
|
||||
field=underscore(k),
|
||||
field_alias=k,
|
||||
field_type=field_type_generator(k, v),
|
||||
field_value=v.default
|
||||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
|
||||
for fields in additional_fields:
|
||||
self._model_def.append(ModelDef(
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"],
|
||||
field_exclude=fields["exclude"] if "exclude" in fields else False))
|
||||
|
||||
def generate_model(self):
|
||||
"""
|
||||
Creates a pydantic BaseModel
|
||||
from the json and overrides provided at initialization
|
||||
"""
|
||||
fields = {
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def
|
||||
}
|
||||
DynamicModel = create_model(self._model_name, **fields)
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
return DynamicModel
|
||||
|
||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||
).generate_model()
|
||||
|
||||
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
StableDiffusionProcessingImg2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
|
||||
).generate_model()
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ExtrasBaseRequest(BaseModel):
|
||||
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
|
||||
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
|
||||
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
|
||||
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
|
||||
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
|
||||
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
|
||||
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
|
||||
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
|
||||
upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
|
||||
|
||||
class ExtraBaseResponse(BaseModel):
|
||||
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
|
||||
|
||||
class ExtrasSingleImageRequest(ExtrasBaseRequest):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
|
||||
class ExtrasSingleImageResponse(ExtraBaseResponse):
|
||||
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
|
||||
class FileData(BaseModel):
|
||||
data: str = Field(title="File data", description="Base64 representation of the file")
|
||||
name: str = Field(title="File name")
|
||||
|
||||
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
||||
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
|
||||
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
||||
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
|
||||
class PNGInfoRequest(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
||||
|
||||
class PNGInfoResponse(BaseModel):
|
||||
info: str = Field(title="Image info", description="A string with all the info the image had")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||
|
||||
class InterrogateRequest(BaseModel):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
model: str = Field(default="clip", title="Model", description="The interrogate model used.")
|
||||
|
||||
class InterrogateResponse(BaseModel):
|
||||
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
|
||||
|
||||
class TrainResponse(BaseModel):
|
||||
info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
|
||||
|
||||
class CreateResponse(BaseModel):
|
||||
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
|
||||
|
||||
class PreprocessResponse(BaseModel):
|
||||
info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
|
||||
|
||||
fields = {}
|
||||
for key, metadata in opts.data_labels.items():
|
||||
value = opts.data.get(key)
|
||||
optType = opts.typemap.get(type(metadata.default), type(value))
|
||||
|
||||
if (metadata is not None):
|
||||
fields.update({key: (Optional[optType], Field(
|
||||
default=metadata.default ,description=metadata.label))})
|
||||
else:
|
||||
fields.update({key: (Optional[optType], Field())})
|
||||
|
||||
OptionsModel = create_model("Options", **fields)
|
||||
|
||||
flags = {}
|
||||
_options = vars(parser)['_option_string_actions']
|
||||
for key in _options:
|
||||
if(_options[key].dest != 'help'):
|
||||
flag = _options[key]
|
||||
_type = str
|
||||
if _options[key].default is not None: _type = type(_options[key].default)
|
||||
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
|
||||
|
||||
FlagsModel = create_model("Flags", **flags)
|
||||
|
||||
class SamplerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
aliases: List[str] = Field(title="Aliases")
|
||||
options: Dict[str, str] = Field(title="Options")
|
||||
|
||||
class UpscalerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
model_name: Optional[str] = Field(title="Model Name")
|
||||
model_path: Optional[str] = Field(title="Path")
|
||||
model_url: Optional[str] = Field(title="URL")
|
||||
|
||||
class SDModelItem(BaseModel):
|
||||
title: str = Field(title="Title")
|
||||
model_name: str = Field(title="Model Name")
|
||||
hash: str = Field(title="Hash")
|
||||
filename: str = Field(title="Filename")
|
||||
config: str = Field(title="Config file")
|
||||
|
||||
class HypernetworkItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
|
||||
class FaceRestorerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
cmd_dir: Optional[str] = Field(title="Path")
|
||||
|
||||
class RealesrganItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
path: Optional[str] = Field(title="Path")
|
||||
scale: Optional[int] = Field(title="Scale")
|
||||
|
||||
class PromptStyleItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
prompt: Optional[str] = Field(title="Prompt")
|
||||
negative_prompt: Optional[str] = Field(title="Negative Prompt")
|
||||
|
||||
class ArtistItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
score: float = Field(title="Score")
|
||||
category: str = Field(title="Category")
|
||||
|
||||
class EmbeddingItem(BaseModel):
|
||||
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
|
||||
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
|
||||
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
|
||||
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
|
||||
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
@ -1,99 +0,0 @@
|
||||
from inflection import underscore
|
||||
from typing import Any, Dict, Optional
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img
|
||||
import inspect
|
||||
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
"kwargs",
|
||||
"sd_model",
|
||||
"outpath_samples",
|
||||
"outpath_grids",
|
||||
"sampler_index",
|
||||
"do_not_save_samples",
|
||||
"do_not_save_grid",
|
||||
"extra_generation_params",
|
||||
"overlay_images",
|
||||
"do_not_reload_embeddings",
|
||||
"seed_enable_extras",
|
||||
"prompt_for_display",
|
||||
"sampler_noise_scheduler_override",
|
||||
"ddim_discretize"
|
||||
]
|
||||
|
||||
class ModelDef(BaseModel):
|
||||
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
||||
|
||||
field: str
|
||||
field_alias: str
|
||||
field_type: Any
|
||||
field_value: Any
|
||||
|
||||
|
||||
class PydanticModelGenerator:
|
||||
"""
|
||||
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
||||
source_data is a snapshot of the default values produced by the class
|
||||
params are the names of the actual keys required by __init__
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = None,
|
||||
class_instance = None,
|
||||
additional_fields = None,
|
||||
):
|
||||
def field_type_generator(k, v):
|
||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||
# print(k, v.annotation, v.default)
|
||||
field_type = v.annotation
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
def merge_class_params(class_):
|
||||
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||
parameters = {}
|
||||
for classes in all_classes:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
self._model_def = [
|
||||
ModelDef(
|
||||
field=underscore(k),
|
||||
field_alias=k,
|
||||
field_type=field_type_generator(k, v),
|
||||
field_value=v.default
|
||||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
|
||||
for fields in additional_fields:
|
||||
self._model_def.append(ModelDef(
|
||||
field=underscore(fields["key"]),
|
||||
field_alias=fields["key"],
|
||||
field_type=fields["type"],
|
||||
field_value=fields["default"]))
|
||||
|
||||
def generate_model(self):
|
||||
"""
|
||||
Creates a pydantic BaseModel
|
||||
from the json and overrides provided at initialization
|
||||
"""
|
||||
fields = {
|
||||
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
|
||||
}
|
||||
DynamicModel = create_model(self._model_name, **fields)
|
||||
DynamicModel.__config__.allow_population_by_field_name = True
|
||||
DynamicModel.__config__.allow_mutation = True
|
||||
return DynamicModel
|
||||
|
||||
StableDiffusionProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||
).generate_model()
|
@ -1,76 +0,0 @@
|
||||
import os.path
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.upscaler
|
||||
from modules import devices, modelloader
|
||||
from modules.bsrgan_model_arch import RRDBNet
|
||||
|
||||
|
||||
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "BSRGAN"
|
||||
self.model_name = "BSRGAN 4x"
|
||||
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
if "http" in file:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
try:
|
||||
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
except Exception:
|
||||
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
self.scalers = scalers
|
||||
|
||||
def do_upscale(self, img: PIL.Image, selected_file):
|
||||
torch.cuda.empty_cache()
|
||||
model = self.load_model(selected_file)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(devices.device_bsrgan)
|
||||
torch.cuda.empty_cache()
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(devices.device_bsrgan)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = 255. * np.moveaxis(output, 0, 2)
|
||||
output = output.astype(np.uint8)
|
||||
output = output[:, :, ::-1]
|
||||
torch.cuda.empty_cache()
|
||||
return PIL.Image.fromarray(output, 'RGB')
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||
progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
|
||||
return None
|
||||
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
|
||||
model.load_state_dict(torch.load(filename), strict=True)
|
||||
model.eval()
|
||||
for k, v in model.named_parameters():
|
||||
v.requires_grad = False
|
||||
return model
|
||||
|
@ -1,102 +0,0 @@
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
def initialize_weights(net_l, scale=1):
|
||||
if not isinstance(net_l, list):
|
||||
net_l = [net_l]
|
||||
for net in net_l:
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale # for residual block
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
self.sf = sf
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.sf==4:
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
if self.sf==4:
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
98
modules/call_queue.py
Normal file
98
modules/call_queue.py
Normal file
@ -0,0 +1,98 @@
|
||||
import html
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
import time
|
||||
|
||||
from modules import shared
|
||||
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
||||
def wrap_queued_call(func):
|
||||
def f(*args, **kwargs):
|
||||
with queue_lock:
|
||||
res = func(*args, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||
def f(*args, **kwargs):
|
||||
|
||||
shared.state.begin()
|
||||
|
||||
with queue_lock:
|
||||
res = func(*args, **kwargs)
|
||||
|
||||
shared.state.end()
|
||||
|
||||
return res
|
||||
|
||||
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
|
||||
|
||||
|
||||
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
|
||||
if run_memmon:
|
||||
shared.mem_mon.monitor()
|
||||
t = time.perf_counter()
|
||||
|
||||
try:
|
||||
res = list(func(*args, **kwargs))
|
||||
except Exception as e:
|
||||
# When printing out our debug argument list, do not print out more than a MB of text
|
||||
max_debug_str_len = 131072 # (1024*1024)/8
|
||||
|
||||
print("Error completing request", file=sys.stderr)
|
||||
argStr = f"Arguments: {str(args)} {str(kwargs)}"
|
||||
print(argStr[:max_debug_str_len], file=sys.stderr)
|
||||
if len(argStr) > max_debug_str_len:
|
||||
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
||||
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
shared.state.job = ""
|
||||
shared.state.job_count = 0
|
||||
|
||||
if extra_outputs_array is None:
|
||||
extra_outputs_array = [None, '']
|
||||
|
||||
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
|
||||
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
if not add_stats:
|
||||
return tuple(res)
|
||||
|
||||
elapsed = time.perf_counter() - t
|
||||
elapsed_m = int(elapsed // 60)
|
||||
elapsed_s = elapsed % 60
|
||||
elapsed_text = f"{elapsed_s:.2f}s"
|
||||
if elapsed_m > 0:
|
||||
elapsed_text = f"{elapsed_m}m "+elapsed_text
|
||||
|
||||
if run_memmon:
|
||||
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
||||
active_peak = mem_stats['active_peak']
|
||||
reserved_peak = mem_stats['reserved_peak']
|
||||
sys_peak = mem_stats['system_peak']
|
||||
sys_total = mem_stats['total']
|
||||
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
|
||||
|
||||
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
|
||||
else:
|
||||
vram_html = ''
|
||||
|
||||
# last item is always HTML
|
||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
||||
|
||||
return tuple(res)
|
||||
|
||||
return f
|
||||
|
@ -382,7 +382,7 @@ class VQAutoEncoder(nn.Module):
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
||||
logger.info(f'vqgan is loaded from: {model_path} [params]')
|
||||
else:
|
||||
raise ValueError(f'Wrong params!')
|
||||
raise ValueError('Wrong params!')
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
@ -431,7 +431,7 @@ class VQGANDiscriminator(nn.Module):
|
||||
elif 'params' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
||||
else:
|
||||
raise ValueError(f'Wrong params!')
|
||||
raise ValueError('Wrong params!')
|
||||
|
||||
def forward(self, x):
|
||||
return self.main(x)
|
@ -36,6 +36,7 @@ def setup_model(dirname):
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from basicsr.utils import imwrite, img2tensor, tensor2img
|
||||
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from facelib.detection.retinaface import retinaface
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
net_class = CodeFormer
|
||||
@ -65,6 +66,8 @@ def setup_model(dirname):
|
||||
net.load_state_dict(checkpoint)
|
||||
net.eval()
|
||||
|
||||
if hasattr(retinaface, 'device'):
|
||||
retinaface.device = devices.device_codeformer
|
||||
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
|
||||
|
||||
self.net = net
|
||||
|
@ -1,172 +1,99 @@
|
||||
import os.path
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
import multiprocessing
|
||||
import time
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
|
||||
def get_deepbooru_tags(pil_image):
|
||||
"""
|
||||
This method is for running only one image at a time for simple use. Used to the img2img interrogate.
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
|
||||
try:
|
||||
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
|
||||
return get_tags_from_process(pil_image)
|
||||
finally:
|
||||
release_process()
|
||||
class DeepDanbooru:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
|
||||
def load(self):
|
||||
if self.model is not None:
|
||||
return
|
||||
|
||||
OPT_INCLUDE_RANKS = "include_ranks"
|
||||
def create_deepbooru_opts():
|
||||
from modules import shared
|
||||
|
||||
return {
|
||||
"use_spaces": shared.opts.deepbooru_use_spaces,
|
||||
"use_escape": shared.opts.deepbooru_escape,
|
||||
"alpha_sort": shared.opts.deepbooru_sort_alpha,
|
||||
OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
|
||||
}
|
||||
|
||||
|
||||
def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
|
||||
model, tags = get_deepbooru_tags_model()
|
||||
while True: # while process is running, keep monitoring queue for new image
|
||||
pil_image = queue.get()
|
||||
if pil_image == "QUIT":
|
||||
break
|
||||
else:
|
||||
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
|
||||
|
||||
|
||||
def create_deepbooru_process(threshold, deepbooru_opts):
|
||||
"""
|
||||
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
|
||||
to be processed in a row without reloading the model or creating a new process. To return the data, a shared
|
||||
dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
|
||||
to the dictionary and the method adding the image to the queue should wait for this value to be updated with
|
||||
the tags.
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
shared.deepbooru_process_manager = multiprocessing.Manager()
|
||||
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
|
||||
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
|
||||
shared.deepbooru_process.start()
|
||||
|
||||
|
||||
def get_tags_from_process(image):
|
||||
from modules import shared
|
||||
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
shared.deepbooru_process_queue.put(image)
|
||||
while shared.deepbooru_process_return["value"] == -1:
|
||||
time.sleep(0.2)
|
||||
caption = shared.deepbooru_process_return["value"]
|
||||
shared.deepbooru_process_return["value"] = -1
|
||||
|
||||
return caption
|
||||
|
||||
|
||||
def release_process():
|
||||
"""
|
||||
Stops the deepbooru process to return used memory
|
||||
"""
|
||||
from modules import shared # prevents circular reference
|
||||
shared.deepbooru_process_queue.put("QUIT")
|
||||
shared.deepbooru_process.join()
|
||||
shared.deepbooru_process_queue = None
|
||||
shared.deepbooru_process = None
|
||||
shared.deepbooru_process_return = None
|
||||
shared.deepbooru_process_manager = None
|
||||
|
||||
def get_deepbooru_tags_model():
|
||||
import deepdanbooru as dd
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
this_folder = os.path.dirname(__file__)
|
||||
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
|
||||
if not os.path.exists(os.path.join(model_path, 'project.json')):
|
||||
# there is no point importing these every time
|
||||
import zipfile
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
load_file_from_url(
|
||||
r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
|
||||
model_path)
|
||||
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
|
||||
zip_ref.extractall(model_path)
|
||||
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
|
||||
|
||||
tags = dd.project.load_tags_from_project(model_path)
|
||||
model = dd.project.load_model_from_project(
|
||||
model_path, compile_model=False
|
||||
files = modelloader.load_models(
|
||||
model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
|
||||
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
|
||||
ext_filter=[".pt"],
|
||||
download_name='model-resnet_custom_v3.pt',
|
||||
)
|
||||
return model, tags
|
||||
|
||||
self.model = deepbooru_model.DeepDanbooruModel()
|
||||
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
|
||||
|
||||
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
|
||||
import deepdanbooru as dd
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
self.model.eval()
|
||||
self.model.to(devices.cpu, devices.dtype)
|
||||
|
||||
alpha_sort = deepbooru_opts['alpha_sort']
|
||||
use_spaces = deepbooru_opts['use_spaces']
|
||||
use_escape = deepbooru_opts['use_escape']
|
||||
include_ranks = deepbooru_opts['include_ranks']
|
||||
def start(self):
|
||||
self.load()
|
||||
self.model.to(devices.device)
|
||||
|
||||
width = model.input_shape[2]
|
||||
height = model.input_shape[1]
|
||||
image = np.array(pil_image)
|
||||
image = tf.image.resize(
|
||||
image,
|
||||
size=(height, width),
|
||||
method=tf.image.ResizeMethod.AREA,
|
||||
preserve_aspect_ratio=True,
|
||||
)
|
||||
image = image.numpy() # EagerTensor to np.array
|
||||
image = dd.image.transform_and_pad_image(image, width, height)
|
||||
image = image / 255.0
|
||||
image_shape = image.shape
|
||||
image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
|
||||
def stop(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
self.model.to(devices.cpu)
|
||||
devices.torch_gc()
|
||||
|
||||
y = model.predict(image)[0]
|
||||
def tag(self, pil_image):
|
||||
self.start()
|
||||
res = self.tag_multi(pil_image)
|
||||
self.stop()
|
||||
|
||||
result_dict = {}
|
||||
return res
|
||||
|
||||
for i, tag in enumerate(tags):
|
||||
result_dict[tag] = y[i]
|
||||
def tag_multi(self, pil_image, force_disable_ranks=False):
|
||||
threshold = shared.opts.interrogate_deepbooru_score_threshold
|
||||
use_spaces = shared.opts.deepbooru_use_spaces
|
||||
use_escape = shared.opts.deepbooru_escape
|
||||
alpha_sort = shared.opts.deepbooru_sort_alpha
|
||||
include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks
|
||||
|
||||
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
|
||||
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
|
||||
|
||||
with torch.no_grad(), devices.autocast():
|
||||
x = torch.from_numpy(a).to(devices.device)
|
||||
y = self.model(x)[0].detach().cpu().numpy()
|
||||
|
||||
probability_dict = {}
|
||||
|
||||
for tag, probability in zip(self.model.tags, y):
|
||||
if probability < threshold:
|
||||
continue
|
||||
|
||||
unsorted_tags_in_theshold = []
|
||||
result_tags_print = []
|
||||
for tag in tags:
|
||||
if result_dict[tag] >= threshold:
|
||||
if tag.startswith("rating:"):
|
||||
continue
|
||||
unsorted_tags_in_theshold.append((result_dict[tag], tag))
|
||||
result_tags_print.append(f'{result_dict[tag]} {tag}')
|
||||
|
||||
# sort tags
|
||||
result_tags_out = []
|
||||
sort_ndx = 0
|
||||
probability_dict[tag] = probability
|
||||
|
||||
if alpha_sort:
|
||||
sort_ndx = 1
|
||||
tags = sorted(probability_dict)
|
||||
else:
|
||||
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
|
||||
|
||||
# sort by reverse by likelihood and normal for alpha, and format tag text as requested
|
||||
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
|
||||
for weight, tag in unsorted_tags_in_theshold:
|
||||
res = []
|
||||
|
||||
filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
|
||||
|
||||
for tag in [x for x in tags if x not in filtertags]:
|
||||
probability = probability_dict[tag]
|
||||
tag_outformat = tag
|
||||
if use_spaces:
|
||||
tag_outformat = tag_outformat.replace('_', ' ')
|
||||
if use_escape:
|
||||
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
|
||||
if include_ranks:
|
||||
tag_outformat = f"({tag_outformat}:{weight:.3f})"
|
||||
tag_outformat = f"({tag_outformat}:{probability:.3f})"
|
||||
|
||||
result_tags_out.append(tag_outformat)
|
||||
res.append(tag_outformat)
|
||||
|
||||
print('\n'.join(sorted(result_tags_print, reverse=True)))
|
||||
return ", ".join(res)
|
||||
|
||||
return ', '.join(result_tags_out)
|
||||
|
||||
model = DeepDanbooru()
|
||||
|
676
modules/deepbooru_model.py
Normal file
676
modules/deepbooru_model.py
Normal file
@ -0,0 +1,676 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
|
||||
|
||||
|
||||
class DeepDanbooruModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(DeepDanbooruModel, self).__init__()
|
||||
|
||||
self.tags = []
|
||||
|
||||
self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
|
||||
self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
|
||||
self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
|
||||
self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
||||
self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
|
||||
self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
|
||||
self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
|
||||
self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
|
||||
self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
|
||||
self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
|
||||
self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
|
||||
self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
|
||||
self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
|
||||
self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
|
||||
self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
||||
self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
|
||||
self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
|
||||
self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
|
||||
self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
|
||||
self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
|
||||
self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
|
||||
self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
|
||||
self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||
self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||
self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
|
||||
self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
|
||||
self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
|
||||
self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
|
||||
self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
|
||||
self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
|
||||
self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||
self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||
self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
|
||||
self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
|
||||
self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
|
||||
self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
|
||||
|
||||
def forward(self, *inputs):
|
||||
t_358, = inputs
|
||||
t_359 = t_358.permute(*[0, 3, 1, 2])
|
||||
t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
|
||||
t_360 = self.n_Conv_0(t_359_padded)
|
||||
t_361 = F.relu(t_360)
|
||||
t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
|
||||
t_362 = self.n_MaxPool_0(t_361)
|
||||
t_363 = self.n_Conv_1(t_362)
|
||||
t_364 = self.n_Conv_2(t_362)
|
||||
t_365 = F.relu(t_364)
|
||||
t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
|
||||
t_366 = self.n_Conv_3(t_365_padded)
|
||||
t_367 = F.relu(t_366)
|
||||
t_368 = self.n_Conv_4(t_367)
|
||||
t_369 = torch.add(t_368, t_363)
|
||||
t_370 = F.relu(t_369)
|
||||
t_371 = self.n_Conv_5(t_370)
|
||||
t_372 = F.relu(t_371)
|
||||
t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
|
||||
t_373 = self.n_Conv_6(t_372_padded)
|
||||
t_374 = F.relu(t_373)
|
||||
t_375 = self.n_Conv_7(t_374)
|
||||
t_376 = torch.add(t_375, t_370)
|
||||
t_377 = F.relu(t_376)
|
||||
t_378 = self.n_Conv_8(t_377)
|
||||
t_379 = F.relu(t_378)
|
||||
t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
|
||||
t_380 = self.n_Conv_9(t_379_padded)
|
||||
t_381 = F.relu(t_380)
|
||||
t_382 = self.n_Conv_10(t_381)
|
||||
t_383 = torch.add(t_382, t_377)
|
||||
t_384 = F.relu(t_383)
|
||||
t_385 = self.n_Conv_11(t_384)
|
||||
t_386 = self.n_Conv_12(t_384)
|
||||
t_387 = F.relu(t_386)
|
||||
t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
|
||||
t_388 = self.n_Conv_13(t_387_padded)
|
||||
t_389 = F.relu(t_388)
|
||||
t_390 = self.n_Conv_14(t_389)
|
||||
t_391 = torch.add(t_390, t_385)
|
||||
t_392 = F.relu(t_391)
|
||||
t_393 = self.n_Conv_15(t_392)
|
||||
t_394 = F.relu(t_393)
|
||||
t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
|
||||
t_395 = self.n_Conv_16(t_394_padded)
|
||||
t_396 = F.relu(t_395)
|
||||
t_397 = self.n_Conv_17(t_396)
|
||||
t_398 = torch.add(t_397, t_392)
|
||||
t_399 = F.relu(t_398)
|
||||
t_400 = self.n_Conv_18(t_399)
|
||||
t_401 = F.relu(t_400)
|
||||
t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
|
||||
t_402 = self.n_Conv_19(t_401_padded)
|
||||
t_403 = F.relu(t_402)
|
||||
t_404 = self.n_Conv_20(t_403)
|
||||
t_405 = torch.add(t_404, t_399)
|
||||
t_406 = F.relu(t_405)
|
||||
t_407 = self.n_Conv_21(t_406)
|
||||
t_408 = F.relu(t_407)
|
||||
t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
|
||||
t_409 = self.n_Conv_22(t_408_padded)
|
||||
t_410 = F.relu(t_409)
|
||||
t_411 = self.n_Conv_23(t_410)
|
||||
t_412 = torch.add(t_411, t_406)
|
||||
t_413 = F.relu(t_412)
|
||||
t_414 = self.n_Conv_24(t_413)
|
||||
t_415 = F.relu(t_414)
|
||||
t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
|
||||
t_416 = self.n_Conv_25(t_415_padded)
|
||||
t_417 = F.relu(t_416)
|
||||
t_418 = self.n_Conv_26(t_417)
|
||||
t_419 = torch.add(t_418, t_413)
|
||||
t_420 = F.relu(t_419)
|
||||
t_421 = self.n_Conv_27(t_420)
|
||||
t_422 = F.relu(t_421)
|
||||
t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
|
||||
t_423 = self.n_Conv_28(t_422_padded)
|
||||
t_424 = F.relu(t_423)
|
||||
t_425 = self.n_Conv_29(t_424)
|
||||
t_426 = torch.add(t_425, t_420)
|
||||
t_427 = F.relu(t_426)
|
||||
t_428 = self.n_Conv_30(t_427)
|
||||
t_429 = F.relu(t_428)
|
||||
t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
|
||||
t_430 = self.n_Conv_31(t_429_padded)
|
||||
t_431 = F.relu(t_430)
|
||||
t_432 = self.n_Conv_32(t_431)
|
||||
t_433 = torch.add(t_432, t_427)
|
||||
t_434 = F.relu(t_433)
|
||||
t_435 = self.n_Conv_33(t_434)
|
||||
t_436 = F.relu(t_435)
|
||||
t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
|
||||
t_437 = self.n_Conv_34(t_436_padded)
|
||||
t_438 = F.relu(t_437)
|
||||
t_439 = self.n_Conv_35(t_438)
|
||||
t_440 = torch.add(t_439, t_434)
|
||||
t_441 = F.relu(t_440)
|
||||
t_442 = self.n_Conv_36(t_441)
|
||||
t_443 = self.n_Conv_37(t_441)
|
||||
t_444 = F.relu(t_443)
|
||||
t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
|
||||
t_445 = self.n_Conv_38(t_444_padded)
|
||||
t_446 = F.relu(t_445)
|
||||
t_447 = self.n_Conv_39(t_446)
|
||||
t_448 = torch.add(t_447, t_442)
|
||||
t_449 = F.relu(t_448)
|
||||
t_450 = self.n_Conv_40(t_449)
|
||||
t_451 = F.relu(t_450)
|
||||
t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
|
||||
t_452 = self.n_Conv_41(t_451_padded)
|
||||
t_453 = F.relu(t_452)
|
||||
t_454 = self.n_Conv_42(t_453)
|
||||
t_455 = torch.add(t_454, t_449)
|
||||
t_456 = F.relu(t_455)
|
||||
t_457 = self.n_Conv_43(t_456)
|
||||
t_458 = F.relu(t_457)
|
||||
t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
|
||||
t_459 = self.n_Conv_44(t_458_padded)
|
||||
t_460 = F.relu(t_459)
|
||||
t_461 = self.n_Conv_45(t_460)
|
||||
t_462 = torch.add(t_461, t_456)
|
||||
t_463 = F.relu(t_462)
|
||||
t_464 = self.n_Conv_46(t_463)
|
||||
t_465 = F.relu(t_464)
|
||||
t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
|
||||
t_466 = self.n_Conv_47(t_465_padded)
|
||||
t_467 = F.relu(t_466)
|
||||
t_468 = self.n_Conv_48(t_467)
|
||||
t_469 = torch.add(t_468, t_463)
|
||||
t_470 = F.relu(t_469)
|
||||
t_471 = self.n_Conv_49(t_470)
|
||||
t_472 = F.relu(t_471)
|
||||
t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
|
||||
t_473 = self.n_Conv_50(t_472_padded)
|
||||
t_474 = F.relu(t_473)
|
||||
t_475 = self.n_Conv_51(t_474)
|
||||
t_476 = torch.add(t_475, t_470)
|
||||
t_477 = F.relu(t_476)
|
||||
t_478 = self.n_Conv_52(t_477)
|
||||
t_479 = F.relu(t_478)
|
||||
t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
|
||||
t_480 = self.n_Conv_53(t_479_padded)
|
||||
t_481 = F.relu(t_480)
|
||||
t_482 = self.n_Conv_54(t_481)
|
||||
t_483 = torch.add(t_482, t_477)
|
||||
t_484 = F.relu(t_483)
|
||||
t_485 = self.n_Conv_55(t_484)
|
||||
t_486 = F.relu(t_485)
|
||||
t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
|
||||
t_487 = self.n_Conv_56(t_486_padded)
|
||||
t_488 = F.relu(t_487)
|
||||
t_489 = self.n_Conv_57(t_488)
|
||||
t_490 = torch.add(t_489, t_484)
|
||||
t_491 = F.relu(t_490)
|
||||
t_492 = self.n_Conv_58(t_491)
|
||||
t_493 = F.relu(t_492)
|
||||
t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
|
||||
t_494 = self.n_Conv_59(t_493_padded)
|
||||
t_495 = F.relu(t_494)
|
||||
t_496 = self.n_Conv_60(t_495)
|
||||
t_497 = torch.add(t_496, t_491)
|
||||
t_498 = F.relu(t_497)
|
||||
t_499 = self.n_Conv_61(t_498)
|
||||
t_500 = F.relu(t_499)
|
||||
t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
|
||||
t_501 = self.n_Conv_62(t_500_padded)
|
||||
t_502 = F.relu(t_501)
|
||||
t_503 = self.n_Conv_63(t_502)
|
||||
t_504 = torch.add(t_503, t_498)
|
||||
t_505 = F.relu(t_504)
|
||||
t_506 = self.n_Conv_64(t_505)
|
||||
t_507 = F.relu(t_506)
|
||||
t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
|
||||
t_508 = self.n_Conv_65(t_507_padded)
|
||||
t_509 = F.relu(t_508)
|
||||
t_510 = self.n_Conv_66(t_509)
|
||||
t_511 = torch.add(t_510, t_505)
|
||||
t_512 = F.relu(t_511)
|
||||
t_513 = self.n_Conv_67(t_512)
|
||||
t_514 = F.relu(t_513)
|
||||
t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
|
||||
t_515 = self.n_Conv_68(t_514_padded)
|
||||
t_516 = F.relu(t_515)
|
||||
t_517 = self.n_Conv_69(t_516)
|
||||
t_518 = torch.add(t_517, t_512)
|
||||
t_519 = F.relu(t_518)
|
||||
t_520 = self.n_Conv_70(t_519)
|
||||
t_521 = F.relu(t_520)
|
||||
t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
|
||||
t_522 = self.n_Conv_71(t_521_padded)
|
||||
t_523 = F.relu(t_522)
|
||||
t_524 = self.n_Conv_72(t_523)
|
||||
t_525 = torch.add(t_524, t_519)
|
||||
t_526 = F.relu(t_525)
|
||||
t_527 = self.n_Conv_73(t_526)
|
||||
t_528 = F.relu(t_527)
|
||||
t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
|
||||
t_529 = self.n_Conv_74(t_528_padded)
|
||||
t_530 = F.relu(t_529)
|
||||
t_531 = self.n_Conv_75(t_530)
|
||||
t_532 = torch.add(t_531, t_526)
|
||||
t_533 = F.relu(t_532)
|
||||
t_534 = self.n_Conv_76(t_533)
|
||||
t_535 = F.relu(t_534)
|
||||
t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
|
||||
t_536 = self.n_Conv_77(t_535_padded)
|
||||
t_537 = F.relu(t_536)
|
||||
t_538 = self.n_Conv_78(t_537)
|
||||
t_539 = torch.add(t_538, t_533)
|
||||
t_540 = F.relu(t_539)
|
||||
t_541 = self.n_Conv_79(t_540)
|
||||
t_542 = F.relu(t_541)
|
||||
t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
|
||||
t_543 = self.n_Conv_80(t_542_padded)
|
||||
t_544 = F.relu(t_543)
|
||||
t_545 = self.n_Conv_81(t_544)
|
||||
t_546 = torch.add(t_545, t_540)
|
||||
t_547 = F.relu(t_546)
|
||||
t_548 = self.n_Conv_82(t_547)
|
||||
t_549 = F.relu(t_548)
|
||||
t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
|
||||
t_550 = self.n_Conv_83(t_549_padded)
|
||||
t_551 = F.relu(t_550)
|
||||
t_552 = self.n_Conv_84(t_551)
|
||||
t_553 = torch.add(t_552, t_547)
|
||||
t_554 = F.relu(t_553)
|
||||
t_555 = self.n_Conv_85(t_554)
|
||||
t_556 = F.relu(t_555)
|
||||
t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
|
||||
t_557 = self.n_Conv_86(t_556_padded)
|
||||
t_558 = F.relu(t_557)
|
||||
t_559 = self.n_Conv_87(t_558)
|
||||
t_560 = torch.add(t_559, t_554)
|
||||
t_561 = F.relu(t_560)
|
||||
t_562 = self.n_Conv_88(t_561)
|
||||
t_563 = F.relu(t_562)
|
||||
t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
|
||||
t_564 = self.n_Conv_89(t_563_padded)
|
||||
t_565 = F.relu(t_564)
|
||||
t_566 = self.n_Conv_90(t_565)
|
||||
t_567 = torch.add(t_566, t_561)
|
||||
t_568 = F.relu(t_567)
|
||||
t_569 = self.n_Conv_91(t_568)
|
||||
t_570 = F.relu(t_569)
|
||||
t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
|
||||
t_571 = self.n_Conv_92(t_570_padded)
|
||||
t_572 = F.relu(t_571)
|
||||
t_573 = self.n_Conv_93(t_572)
|
||||
t_574 = torch.add(t_573, t_568)
|
||||
t_575 = F.relu(t_574)
|
||||
t_576 = self.n_Conv_94(t_575)
|
||||
t_577 = F.relu(t_576)
|
||||
t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
|
||||
t_578 = self.n_Conv_95(t_577_padded)
|
||||
t_579 = F.relu(t_578)
|
||||
t_580 = self.n_Conv_96(t_579)
|
||||
t_581 = torch.add(t_580, t_575)
|
||||
t_582 = F.relu(t_581)
|
||||
t_583 = self.n_Conv_97(t_582)
|
||||
t_584 = F.relu(t_583)
|
||||
t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
|
||||
t_585 = self.n_Conv_98(t_584_padded)
|
||||
t_586 = F.relu(t_585)
|
||||
t_587 = self.n_Conv_99(t_586)
|
||||
t_588 = self.n_Conv_100(t_582)
|
||||
t_589 = torch.add(t_587, t_588)
|
||||
t_590 = F.relu(t_589)
|
||||
t_591 = self.n_Conv_101(t_590)
|
||||
t_592 = F.relu(t_591)
|
||||
t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
|
||||
t_593 = self.n_Conv_102(t_592_padded)
|
||||
t_594 = F.relu(t_593)
|
||||
t_595 = self.n_Conv_103(t_594)
|
||||
t_596 = torch.add(t_595, t_590)
|
||||
t_597 = F.relu(t_596)
|
||||
t_598 = self.n_Conv_104(t_597)
|
||||
t_599 = F.relu(t_598)
|
||||
t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
|
||||
t_600 = self.n_Conv_105(t_599_padded)
|
||||
t_601 = F.relu(t_600)
|
||||
t_602 = self.n_Conv_106(t_601)
|
||||
t_603 = torch.add(t_602, t_597)
|
||||
t_604 = F.relu(t_603)
|
||||
t_605 = self.n_Conv_107(t_604)
|
||||
t_606 = F.relu(t_605)
|
||||
t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
|
||||
t_607 = self.n_Conv_108(t_606_padded)
|
||||
t_608 = F.relu(t_607)
|
||||
t_609 = self.n_Conv_109(t_608)
|
||||
t_610 = torch.add(t_609, t_604)
|
||||
t_611 = F.relu(t_610)
|
||||
t_612 = self.n_Conv_110(t_611)
|
||||
t_613 = F.relu(t_612)
|
||||
t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
|
||||
t_614 = self.n_Conv_111(t_613_padded)
|
||||
t_615 = F.relu(t_614)
|
||||
t_616 = self.n_Conv_112(t_615)
|
||||
t_617 = torch.add(t_616, t_611)
|
||||
t_618 = F.relu(t_617)
|
||||
t_619 = self.n_Conv_113(t_618)
|
||||
t_620 = F.relu(t_619)
|
||||
t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
|
||||
t_621 = self.n_Conv_114(t_620_padded)
|
||||
t_622 = F.relu(t_621)
|
||||
t_623 = self.n_Conv_115(t_622)
|
||||
t_624 = torch.add(t_623, t_618)
|
||||
t_625 = F.relu(t_624)
|
||||
t_626 = self.n_Conv_116(t_625)
|
||||
t_627 = F.relu(t_626)
|
||||
t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
|
||||
t_628 = self.n_Conv_117(t_627_padded)
|
||||
t_629 = F.relu(t_628)
|
||||
t_630 = self.n_Conv_118(t_629)
|
||||
t_631 = torch.add(t_630, t_625)
|
||||
t_632 = F.relu(t_631)
|
||||
t_633 = self.n_Conv_119(t_632)
|
||||
t_634 = F.relu(t_633)
|
||||
t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
|
||||
t_635 = self.n_Conv_120(t_634_padded)
|
||||
t_636 = F.relu(t_635)
|
||||
t_637 = self.n_Conv_121(t_636)
|
||||
t_638 = torch.add(t_637, t_632)
|
||||
t_639 = F.relu(t_638)
|
||||
t_640 = self.n_Conv_122(t_639)
|
||||
t_641 = F.relu(t_640)
|
||||
t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
|
||||
t_642 = self.n_Conv_123(t_641_padded)
|
||||
t_643 = F.relu(t_642)
|
||||
t_644 = self.n_Conv_124(t_643)
|
||||
t_645 = torch.add(t_644, t_639)
|
||||
t_646 = F.relu(t_645)
|
||||
t_647 = self.n_Conv_125(t_646)
|
||||
t_648 = F.relu(t_647)
|
||||
t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
|
||||
t_649 = self.n_Conv_126(t_648_padded)
|
||||
t_650 = F.relu(t_649)
|
||||
t_651 = self.n_Conv_127(t_650)
|
||||
t_652 = torch.add(t_651, t_646)
|
||||
t_653 = F.relu(t_652)
|
||||
t_654 = self.n_Conv_128(t_653)
|
||||
t_655 = F.relu(t_654)
|
||||
t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
|
||||
t_656 = self.n_Conv_129(t_655_padded)
|
||||
t_657 = F.relu(t_656)
|
||||
t_658 = self.n_Conv_130(t_657)
|
||||
t_659 = torch.add(t_658, t_653)
|
||||
t_660 = F.relu(t_659)
|
||||
t_661 = self.n_Conv_131(t_660)
|
||||
t_662 = F.relu(t_661)
|
||||
t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
|
||||
t_663 = self.n_Conv_132(t_662_padded)
|
||||
t_664 = F.relu(t_663)
|
||||
t_665 = self.n_Conv_133(t_664)
|
||||
t_666 = torch.add(t_665, t_660)
|
||||
t_667 = F.relu(t_666)
|
||||
t_668 = self.n_Conv_134(t_667)
|
||||
t_669 = F.relu(t_668)
|
||||
t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
|
||||
t_670 = self.n_Conv_135(t_669_padded)
|
||||
t_671 = F.relu(t_670)
|
||||
t_672 = self.n_Conv_136(t_671)
|
||||
t_673 = torch.add(t_672, t_667)
|
||||
t_674 = F.relu(t_673)
|
||||
t_675 = self.n_Conv_137(t_674)
|
||||
t_676 = F.relu(t_675)
|
||||
t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
|
||||
t_677 = self.n_Conv_138(t_676_padded)
|
||||
t_678 = F.relu(t_677)
|
||||
t_679 = self.n_Conv_139(t_678)
|
||||
t_680 = torch.add(t_679, t_674)
|
||||
t_681 = F.relu(t_680)
|
||||
t_682 = self.n_Conv_140(t_681)
|
||||
t_683 = F.relu(t_682)
|
||||
t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
|
||||
t_684 = self.n_Conv_141(t_683_padded)
|
||||
t_685 = F.relu(t_684)
|
||||
t_686 = self.n_Conv_142(t_685)
|
||||
t_687 = torch.add(t_686, t_681)
|
||||
t_688 = F.relu(t_687)
|
||||
t_689 = self.n_Conv_143(t_688)
|
||||
t_690 = F.relu(t_689)
|
||||
t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
|
||||
t_691 = self.n_Conv_144(t_690_padded)
|
||||
t_692 = F.relu(t_691)
|
||||
t_693 = self.n_Conv_145(t_692)
|
||||
t_694 = torch.add(t_693, t_688)
|
||||
t_695 = F.relu(t_694)
|
||||
t_696 = self.n_Conv_146(t_695)
|
||||
t_697 = F.relu(t_696)
|
||||
t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
|
||||
t_698 = self.n_Conv_147(t_697_padded)
|
||||
t_699 = F.relu(t_698)
|
||||
t_700 = self.n_Conv_148(t_699)
|
||||
t_701 = torch.add(t_700, t_695)
|
||||
t_702 = F.relu(t_701)
|
||||
t_703 = self.n_Conv_149(t_702)
|
||||
t_704 = F.relu(t_703)
|
||||
t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
|
||||
t_705 = self.n_Conv_150(t_704_padded)
|
||||
t_706 = F.relu(t_705)
|
||||
t_707 = self.n_Conv_151(t_706)
|
||||
t_708 = torch.add(t_707, t_702)
|
||||
t_709 = F.relu(t_708)
|
||||
t_710 = self.n_Conv_152(t_709)
|
||||
t_711 = F.relu(t_710)
|
||||
t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
|
||||
t_712 = self.n_Conv_153(t_711_padded)
|
||||
t_713 = F.relu(t_712)
|
||||
t_714 = self.n_Conv_154(t_713)
|
||||
t_715 = torch.add(t_714, t_709)
|
||||
t_716 = F.relu(t_715)
|
||||
t_717 = self.n_Conv_155(t_716)
|
||||
t_718 = F.relu(t_717)
|
||||
t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
|
||||
t_719 = self.n_Conv_156(t_718_padded)
|
||||
t_720 = F.relu(t_719)
|
||||
t_721 = self.n_Conv_157(t_720)
|
||||
t_722 = torch.add(t_721, t_716)
|
||||
t_723 = F.relu(t_722)
|
||||
t_724 = self.n_Conv_158(t_723)
|
||||
t_725 = self.n_Conv_159(t_723)
|
||||
t_726 = F.relu(t_725)
|
||||
t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
|
||||
t_727 = self.n_Conv_160(t_726_padded)
|
||||
t_728 = F.relu(t_727)
|
||||
t_729 = self.n_Conv_161(t_728)
|
||||
t_730 = torch.add(t_729, t_724)
|
||||
t_731 = F.relu(t_730)
|
||||
t_732 = self.n_Conv_162(t_731)
|
||||
t_733 = F.relu(t_732)
|
||||
t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
|
||||
t_734 = self.n_Conv_163(t_733_padded)
|
||||
t_735 = F.relu(t_734)
|
||||
t_736 = self.n_Conv_164(t_735)
|
||||
t_737 = torch.add(t_736, t_731)
|
||||
t_738 = F.relu(t_737)
|
||||
t_739 = self.n_Conv_165(t_738)
|
||||
t_740 = F.relu(t_739)
|
||||
t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
|
||||
t_741 = self.n_Conv_166(t_740_padded)
|
||||
t_742 = F.relu(t_741)
|
||||
t_743 = self.n_Conv_167(t_742)
|
||||
t_744 = torch.add(t_743, t_738)
|
||||
t_745 = F.relu(t_744)
|
||||
t_746 = self.n_Conv_168(t_745)
|
||||
t_747 = self.n_Conv_169(t_745)
|
||||
t_748 = F.relu(t_747)
|
||||
t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
|
||||
t_749 = self.n_Conv_170(t_748_padded)
|
||||
t_750 = F.relu(t_749)
|
||||
t_751 = self.n_Conv_171(t_750)
|
||||
t_752 = torch.add(t_751, t_746)
|
||||
t_753 = F.relu(t_752)
|
||||
t_754 = self.n_Conv_172(t_753)
|
||||
t_755 = F.relu(t_754)
|
||||
t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
|
||||
t_756 = self.n_Conv_173(t_755_padded)
|
||||
t_757 = F.relu(t_756)
|
||||
t_758 = self.n_Conv_174(t_757)
|
||||
t_759 = torch.add(t_758, t_753)
|
||||
t_760 = F.relu(t_759)
|
||||
t_761 = self.n_Conv_175(t_760)
|
||||
t_762 = F.relu(t_761)
|
||||
t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
|
||||
t_763 = self.n_Conv_176(t_762_padded)
|
||||
t_764 = F.relu(t_763)
|
||||
t_765 = self.n_Conv_177(t_764)
|
||||
t_766 = torch.add(t_765, t_760)
|
||||
t_767 = F.relu(t_766)
|
||||
t_768 = self.n_Conv_178(t_767)
|
||||
t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
|
||||
t_770 = torch.squeeze(t_769, 3)
|
||||
t_770 = torch.squeeze(t_770, 2)
|
||||
t_771 = torch.sigmoid(t_770)
|
||||
return t_771
|
||||
|
||||
def load_state_dict(self, state_dict, **kwargs):
|
||||
self.tags = state_dict.get('tags', [])
|
||||
|
||||
super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
|
||||
|
@ -1,62 +1,96 @@
|
||||
import sys, os, shlex
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
|
||||
from modules import errors
|
||||
from packaging import version
|
||||
|
||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||
has_mps = getattr(torch, 'has_mps', False)
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
||||
# check `getattr` and try it for compatibility
|
||||
def has_mps() -> bool:
|
||||
if not getattr(torch, 'has_mps', False):
|
||||
return False
|
||||
try:
|
||||
torch.zeros(1).to(torch.device("mps"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def extract_device_id(args, name):
|
||||
for x in range(len(args)):
|
||||
if name in args[x]:
|
||||
return args[x + 1]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_cuda_device_string():
|
||||
from modules import shared
|
||||
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"cuda:{shared.cmd_opts.device_id}"
|
||||
|
||||
return "cuda"
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
return torch.device(get_cuda_device_string())
|
||||
|
||||
if has_mps:
|
||||
if has_mps():
|
||||
return torch.device("mps")
|
||||
|
||||
return cpu
|
||||
|
||||
|
||||
def get_device_for(task):
|
||||
from modules import shared
|
||||
|
||||
if task in shared.cmd_opts.use_cpu:
|
||||
return cpu
|
||||
|
||||
return get_optimal_device()
|
||||
|
||||
|
||||
def torch_gc():
|
||||
if torch.cuda.is_available():
|
||||
with torch.cuda.device(get_cuda_device_string()):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def enable_tf32():
|
||||
if torch.cuda.is_available():
|
||||
|
||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||
if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
|
||||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
||||
cpu = torch.device("cpu")
|
||||
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
|
||||
def randn(seed, shape):
|
||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(device=cpu)
|
||||
generator.manual_seed(seed)
|
||||
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
|
||||
return noise
|
||||
|
||||
def randn(seed, shape):
|
||||
torch.manual_seed(seed)
|
||||
if device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def randn_without_seed(shape):
|
||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(device=cpu)
|
||||
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
|
||||
return noise
|
||||
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
@ -70,3 +104,37 @@ def autocast(disable=False):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return torch.autocast("cuda")
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||
orig_tensor_to = torch.Tensor.to
|
||||
def tensor_to_fix(self, *args, **kwargs):
|
||||
if self.device.type != 'mps' and \
|
||||
((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
|
||||
(isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
|
||||
self = self.contiguous()
|
||||
return orig_tensor_to(self, *args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||
orig_layer_norm = torch.nn.functional.layer_norm
|
||||
def layer_norm_fix(*args, **kwargs):
|
||||
if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
|
||||
args = list(args)
|
||||
args[0] = args[0].contiguous()
|
||||
return orig_layer_norm(*args, **kwargs)
|
||||
|
||||
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||
orig_tensor_numpy = torch.Tensor.numpy
|
||||
def numpy_fix(self, *args, **kwargs):
|
||||
if self.requires_grad:
|
||||
self = self.detach()
|
||||
return orig_tensor_numpy(self, *args, **kwargs)
|
||||
|
||||
|
||||
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
|
||||
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
|
||||
torch.Tensor.to = tensor_to_fix
|
||||
torch.nn.functional.layer_norm = layer_norm_fix
|
||||
torch.Tensor.numpy = numpy_fix
|
||||
|
@ -2,9 +2,30 @@ import sys
|
||||
import traceback
|
||||
|
||||
|
||||
def print_error_explanation(message):
|
||||
lines = message.strip().split("\n")
|
||||
max_len = max([len(x) for x in lines])
|
||||
|
||||
print('=' * max_len, file=sys.stderr)
|
||||
for line in lines:
|
||||
print(line, file=sys.stderr)
|
||||
print('=' * max_len, file=sys.stderr)
|
||||
|
||||
|
||||
def display(e: Exception, task):
|
||||
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
message = str(e)
|
||||
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
||||
print_error_explanation("""
|
||||
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
|
||||
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
|
||||
""")
|
||||
|
||||
|
||||
def run(code, task):
|
||||
try:
|
||||
code()
|
||||
except Exception as e:
|
||||
print(f"{task}: {type(e).__name__}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
display(task, e)
|
||||
|
@ -11,62 +11,118 @@ from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
def fix_model_layers(crt_model, pretrained_net):
|
||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
||||
if 'conv_first.weight' in pretrained_net:
|
||||
return pretrained_net
|
||||
|
||||
if 'model.0.weight' not in pretrained_net:
|
||||
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
|
||||
if is_realesrgan:
|
||||
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
|
||||
else:
|
||||
raise Exception("The file is not a ESRGAN model.")
|
||||
def mod2normal(state_dict):
|
||||
# this code is copied from https://github.com/victorca25/iNNfer
|
||||
if 'conv_first.weight' in state_dict:
|
||||
crt_net = {}
|
||||
items = []
|
||||
for k, v in state_dict.items():
|
||||
items.append(k)
|
||||
|
||||
crt_net = crt_model.state_dict()
|
||||
load_net_clean = {}
|
||||
for k, v in pretrained_net.items():
|
||||
if k.startswith('module.'):
|
||||
load_net_clean[k[7:]] = v
|
||||
else:
|
||||
load_net_clean[k] = v
|
||||
pretrained_net = load_net_clean
|
||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||
|
||||
tbd = []
|
||||
for k, v in crt_net.items():
|
||||
tbd.append(k)
|
||||
|
||||
# directly copy
|
||||
for k, v in crt_net.items():
|
||||
if k in pretrained_net and pretrained_net[k].size() == v.size():
|
||||
crt_net[k] = pretrained_net[k]
|
||||
tbd.remove(k)
|
||||
|
||||
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
|
||||
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
|
||||
|
||||
for k in tbd.copy():
|
||||
for k in items.copy():
|
||||
if 'RDB' in k:
|
||||
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
||||
if '.weight' in k:
|
||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||
elif '.bias' in k:
|
||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||
crt_net[k] = pretrained_net[ori_k]
|
||||
tbd.remove(k)
|
||||
crt_net[ori_k] = state_dict[k]
|
||||
items.remove(k)
|
||||
|
||||
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
|
||||
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
|
||||
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
|
||||
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
|
||||
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
|
||||
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
|
||||
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
|
||||
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
|
||||
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
|
||||
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
|
||||
crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
|
||||
crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
|
||||
crt_net['model.3.weight'] = state_dict['upconv1.weight']
|
||||
crt_net['model.3.bias'] = state_dict['upconv1.bias']
|
||||
crt_net['model.6.weight'] = state_dict['upconv2.weight']
|
||||
crt_net['model.6.bias'] = state_dict['upconv2.bias']
|
||||
crt_net['model.8.weight'] = state_dict['HRconv.weight']
|
||||
crt_net['model.8.bias'] = state_dict['HRconv.bias']
|
||||
crt_net['model.10.weight'] = state_dict['conv_last.weight']
|
||||
crt_net['model.10.bias'] = state_dict['conv_last.bias']
|
||||
state_dict = crt_net
|
||||
return state_dict
|
||||
|
||||
|
||||
def resrgan2normal(state_dict, nb=23):
|
||||
# this code is copied from https://github.com/victorca25/iNNfer
|
||||
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
||||
re8x = 0
|
||||
crt_net = {}
|
||||
items = []
|
||||
for k, v in state_dict.items():
|
||||
items.append(k)
|
||||
|
||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||
|
||||
for k in items.copy():
|
||||
if "rdb" in k:
|
||||
ori_k = k.replace('body.', 'model.1.sub.')
|
||||
ori_k = ori_k.replace('.rdb', '.RDB')
|
||||
if '.weight' in k:
|
||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||
elif '.bias' in k:
|
||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||
crt_net[ori_k] = state_dict[k]
|
||||
items.remove(k)
|
||||
|
||||
crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
|
||||
crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
|
||||
crt_net['model.3.weight'] = state_dict['conv_up1.weight']
|
||||
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
|
||||
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
|
||||
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
|
||||
|
||||
if 'conv_up3.weight' in state_dict:
|
||||
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
|
||||
re8x = 3
|
||||
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
|
||||
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
|
||||
|
||||
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
|
||||
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
|
||||
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
|
||||
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
|
||||
|
||||
state_dict = crt_net
|
||||
return state_dict
|
||||
|
||||
|
||||
def infer_params(state_dict):
|
||||
# this code is copied from https://github.com/victorca25/iNNfer
|
||||
scale2x = 0
|
||||
scalemin = 6
|
||||
n_uplayer = 0
|
||||
plus = False
|
||||
|
||||
for block in list(state_dict):
|
||||
parts = block.split(".")
|
||||
n_parts = len(parts)
|
||||
if n_parts == 5 and parts[2] == "sub":
|
||||
nb = int(parts[3])
|
||||
elif n_parts == 3:
|
||||
part_num = int(parts[1])
|
||||
if (part_num > scalemin
|
||||
and parts[0] == "model"
|
||||
and parts[2] == "weight"):
|
||||
scale2x += 1
|
||||
if part_num > n_uplayer:
|
||||
n_uplayer = part_num
|
||||
out_nc = state_dict[block].shape[0]
|
||||
if not plus and "conv1x1" in block:
|
||||
plus = True
|
||||
|
||||
nf = state_dict["model.0.weight"].shape[0]
|
||||
in_nc = state_dict["model.0.weight"].shape[1]
|
||||
out_nc = out_nc
|
||||
scale = 2 ** scale2x
|
||||
|
||||
return in_nc, out_nc, nf, nb, plus, scale
|
||||
|
||||
return crt_net
|
||||
|
||||
class UpscalerESRGAN(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
@ -109,20 +165,39 @@ class UpscalerESRGAN(Upscaler):
|
||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
||||
return None
|
||||
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||
|
||||
pretrained_net = fix_model_layers(crt_model, pretrained_net)
|
||||
crt_model.load_state_dict(pretrained_net)
|
||||
crt_model.eval()
|
||||
if "params_ema" in state_dict:
|
||||
state_dict = state_dict["params_ema"]
|
||||
elif "params" in state_dict:
|
||||
state_dict = state_dict["params"]
|
||||
num_conv = 16 if "realesr-animevideov3" in filename else 32
|
||||
model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
return crt_model
|
||||
if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
|
||||
nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
|
||||
state_dict = resrgan2normal(state_dict, nb)
|
||||
elif "conv_first.weight" in state_dict:
|
||||
state_dict = mod2normal(state_dict)
|
||||
elif "model.0.weight" not in state_dict:
|
||||
raise Exception("The file is not a recognized ESRGAN model.")
|
||||
|
||||
in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
|
||||
|
||||
model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def upscale_without_tiling(model, img):
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(devices.device_esrgan)
|
||||
with torch.no_grad():
|
||||
|
@ -1,80 +1,463 @@
|
||||
# this file is taken from https://github.com/xinntao/ESRGAN
|
||||
# this file is adapted from https://github.com/victorca25/iNNfer
|
||||
|
||||
import math
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
####################
|
||||
# RRDBNet Generator
|
||||
####################
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
|
||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
|
||||
finalact=None, gaussian_noise=False, plus=False):
|
||||
super(RRDBNet, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
if upscale == 3:
|
||||
n_upscale = 1
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.resrgan_scale = 0
|
||||
if in_nc % 16 == 0:
|
||||
self.resrgan_scale = 1
|
||||
elif in_nc != 4 and in_nc % 4 == 0:
|
||||
self.resrgan_scale = 2
|
||||
|
||||
# initialization
|
||||
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
|
||||
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
if upsample_mode == 'upconv':
|
||||
upsample_block = upconv_block
|
||||
elif upsample_mode == 'pixelshuffle':
|
||||
upsample_block = pixelshuffle_block
|
||||
else:
|
||||
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
|
||||
if upscale == 3:
|
||||
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
||||
else:
|
||||
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
|
||||
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
|
||||
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
|
||||
outact = act(finalact) if finalact else None
|
||||
|
||||
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
|
||||
*upsampler, HR_conv0, HR_conv1, outact)
|
||||
|
||||
def forward(self, x, outm=None):
|
||||
if self.resrgan_scale == 1:
|
||||
feat = pixel_unshuffle(x, scale=4)
|
||||
elif self.resrgan_scale == 2:
|
||||
feat = pixel_unshuffle(x, scale=2)
|
||||
else:
|
||||
feat = x
|
||||
|
||||
return self.model(feat)
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
"""
|
||||
Residual in Residual Dense Block
|
||||
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
||||
"""
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
# This is for backwards compatibility with existing models
|
||||
if nr == 3:
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
else:
|
||||
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
|
||||
self.RDBs = nn.Sequential(*RDB_list)
|
||||
|
||||
def forward(self, x):
|
||||
if hasattr(self, 'RDB1'):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
else:
|
||||
out = self.RDBs(x)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
"""
|
||||
Residual Dense Block
|
||||
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
||||
Modified options that can be used:
|
||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||
- "Spectral normalization" arXiv:1802.05957
|
||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||
{Rakotonirina} and A. {Rasoanaivo}
|
||||
"""
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.noise = GaussianNoise() if gaussian_noise else None
|
||||
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
||||
|
||||
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
if mode == 'CNA':
|
||||
last_act = None
|
||||
else:
|
||||
last_act = act_type
|
||||
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
x1 = self.conv1(x)
|
||||
x2 = self.conv2(torch.cat((x, x1), 1))
|
||||
if self.conv1x1:
|
||||
x2 = x2 + self.conv1x1(x)
|
||||
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
||||
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
||||
if self.conv1x1:
|
||||
x4 = x4 + x2
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
if self.noise:
|
||||
return self.noise(x5.mul(0.2) + x)
|
||||
else:
|
||||
return x5 * 0.2 + x
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
####################
|
||||
# ESRGANplus
|
||||
####################
|
||||
|
||||
class GaussianNoise(nn.Module):
|
||||
def __init__(self, sigma=0.1, is_relative_detach=False):
|
||||
super().__init__()
|
||||
self.sigma = sigma
|
||||
self.is_relative_detach = is_relative_detach
|
||||
self.noise = torch.tensor(0, dtype=torch.float)
|
||||
|
||||
def forward(self, x):
|
||||
if self.training and self.sigma != 0:
|
||||
self.noise = self.noise.to(x.device)
|
||||
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||
x = x + sampled_noise
|
||||
return x
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
####################
|
||||
# SRVGGNetCompact
|
||||
####################
|
||||
|
||||
class SRVGGNetCompact(nn.Module):
|
||||
"""A compact VGG-style network structure for super-resolution.
|
||||
This class is copied from https://github.com/xinntao/Real-ESRGAN
|
||||
"""
|
||||
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
||||
super(SRVGGNetCompact, self).__init__()
|
||||
self.num_in_ch = num_in_ch
|
||||
self.num_out_ch = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.num_conv = num_conv
|
||||
self.upscale = upscale
|
||||
self.act_type = act_type
|
||||
|
||||
self.body = nn.ModuleList()
|
||||
# the first conv
|
||||
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
||||
# the first activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the body structure
|
||||
for _ in range(num_conv):
|
||||
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
||||
# activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the last conv
|
||||
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
||||
# upsample
|
||||
self.upsampler = nn.PixelShuffle(upscale)
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
for i in range(0, len(self.body)):
|
||||
out = self.body[i](out)
|
||||
|
||||
out = self.upsampler(out)
|
||||
# add the nearest upsampled image, so that the network learns the residual
|
||||
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
||||
out += base
|
||||
return out
|
||||
|
||||
|
||||
####################
|
||||
# Upsampler
|
||||
####################
|
||||
|
||||
class Upsample(nn.Module):
|
||||
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
||||
The input data is assumed to be of the form
|
||||
`minibatch x channels x [optional depth] x [optional height] x width`.
|
||||
"""
|
||||
|
||||
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
super(Upsample, self).__init__()
|
||||
if isinstance(scale_factor, tuple):
|
||||
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
||||
else:
|
||||
self.scale_factor = float(scale_factor) if scale_factor else None
|
||||
self.mode = mode
|
||||
self.size = size
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
||||
|
||||
def extra_repr(self):
|
||||
if self.scale_factor is not None:
|
||||
info = 'scale_factor=' + str(self.scale_factor)
|
||||
else:
|
||||
info = 'size=' + str(self.size)
|
||||
info += ', mode=' + self.mode
|
||||
return info
|
||||
|
||||
|
||||
def pixel_unshuffle(x, scale):
|
||||
""" Pixel unshuffle.
|
||||
Args:
|
||||
x (Tensor): Input feature with shape (b, c, hh, hw).
|
||||
scale (int): Downsample ratio.
|
||||
Returns:
|
||||
Tensor: the pixel unshuffled feature.
|
||||
"""
|
||||
b, c, hh, hw = x.size()
|
||||
out_channel = c * (scale**2)
|
||||
assert hh % scale == 0 and hw % scale == 0
|
||||
h = hh // scale
|
||||
w = hw // scale
|
||||
x_view = x.view(b, c, h, scale, w, scale)
|
||||
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
||||
|
||||
|
||||
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
|
||||
"""
|
||||
Pixel shuffle layer
|
||||
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
||||
Neural Network, CVPR17)
|
||||
"""
|
||||
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
|
||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
a = act(act_type) if act_type else None
|
||||
return sequential(conv, pixel_shuffle, n, a)
|
||||
|
||||
|
||||
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
|
||||
""" Upconv layer """
|
||||
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
|
||||
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
|
||||
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
|
||||
return sequential(upsample, conv)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
####################
|
||||
# Basic blocks
|
||||
####################
|
||||
|
||||
|
||||
def make_layer(basic_block, num_basic_block, **kwarg):
|
||||
"""Make layers by stacking the same blocks.
|
||||
Args:
|
||||
basic_block (nn.module): nn.module class for basic block. (block)
|
||||
num_basic_block (int): number of blocks. (n_layers)
|
||||
Returns:
|
||||
nn.Sequential: Stacked blocks in nn.Sequential.
|
||||
"""
|
||||
layers = []
|
||||
for _ in range(num_basic_block):
|
||||
layers.append(basic_block(**kwarg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
||||
""" activation helper """
|
||||
act_type = act_type.lower()
|
||||
if act_type == 'relu':
|
||||
layer = nn.ReLU(inplace)
|
||||
elif act_type in ('leakyrelu', 'lrelu'):
|
||||
layer = nn.LeakyReLU(neg_slope, inplace)
|
||||
elif act_type == 'prelu':
|
||||
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
||||
elif act_type == 'tanh': # [-1, 1] range output
|
||||
layer = nn.Tanh()
|
||||
elif act_type == 'sigmoid': # [0, 1] range output
|
||||
layer = nn.Sigmoid()
|
||||
else:
|
||||
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
|
||||
return layer
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self, *kwargs):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x, *kwargs):
|
||||
return x
|
||||
|
||||
|
||||
def norm(norm_type, nc):
|
||||
""" Return a normalization layer """
|
||||
norm_type = norm_type.lower()
|
||||
if norm_type == 'batch':
|
||||
layer = nn.BatchNorm2d(nc, affine=True)
|
||||
elif norm_type == 'instance':
|
||||
layer = nn.InstanceNorm2d(nc, affine=False)
|
||||
elif norm_type == 'none':
|
||||
def norm_layer(x): return Identity()
|
||||
else:
|
||||
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
||||
return layer
|
||||
|
||||
|
||||
def pad(pad_type, padding):
|
||||
""" padding layer helper """
|
||||
pad_type = pad_type.lower()
|
||||
if padding == 0:
|
||||
return None
|
||||
if pad_type == 'reflect':
|
||||
layer = nn.ReflectionPad2d(padding)
|
||||
elif pad_type == 'replicate':
|
||||
layer = nn.ReplicationPad2d(padding)
|
||||
elif pad_type == 'zero':
|
||||
layer = nn.ZeroPad2d(padding)
|
||||
else:
|
||||
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
|
||||
return layer
|
||||
|
||||
|
||||
def get_valid_padding(kernel_size, dilation):
|
||||
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
||||
padding = (kernel_size - 1) // 2
|
||||
return padding
|
||||
|
||||
|
||||
class ShortcutBlock(nn.Module):
|
||||
""" Elementwise sum the output of a submodule to its input """
|
||||
def __init__(self, submodule):
|
||||
super(ShortcutBlock, self).__init__()
|
||||
self.sub = submodule
|
||||
|
||||
def forward(self, x):
|
||||
output = x + self.sub(x)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
|
||||
|
||||
|
||||
def sequential(*args):
|
||||
""" Flatten Sequential. It unwraps nn.Sequential. """
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], OrderedDict):
|
||||
raise NotImplementedError('sequential does not support OrderedDict input.')
|
||||
return args[0] # No sequential is needed.
|
||||
modules = []
|
||||
for module in args:
|
||||
if isinstance(module, nn.Sequential):
|
||||
for submodule in module.children():
|
||||
modules.append(submodule)
|
||||
elif isinstance(module, nn.Module):
|
||||
modules.append(module)
|
||||
return nn.Sequential(*modules)
|
||||
|
||||
|
||||
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False):
|
||||
""" Conv layer with padding, normalization, activation """
|
||||
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||
padding = padding if pad_type == 'zero' else 0
|
||||
|
||||
if convtype=='PartialConv2D':
|
||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='DeformConv2D':
|
||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='Conv3D':
|
||||
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
else:
|
||||
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
|
||||
if spectral_norm:
|
||||
c = nn.utils.spectral_norm(c)
|
||||
|
||||
a = act(act_type) if act_type else None
|
||||
if 'CNA' in mode:
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
return sequential(p, c, n, a)
|
||||
elif mode == 'NAC':
|
||||
if norm_type is None and act_type is not None:
|
||||
a = act(act_type, inplace=False)
|
||||
n = norm(norm_type, in_nc) if norm_type else None
|
||||
return sequential(n, a, p, c)
|
||||
|
99
modules/extensions.py
Normal file
99
modules/extensions.py
Normal file
@ -0,0 +1,99 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import git
|
||||
|
||||
from modules import paths, shared
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
||||
|
||||
|
||||
def active():
|
||||
return [x for x in extensions if x.enabled]
|
||||
|
||||
|
||||
class Extension:
|
||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
||||
self.name = name
|
||||
self.path = path
|
||||
self.enabled = enabled
|
||||
self.status = ''
|
||||
self.can_update = False
|
||||
self.is_builtin = is_builtin
|
||||
|
||||
repo = None
|
||||
try:
|
||||
if os.path.exists(os.path.join(path, ".git")):
|
||||
repo = git.Repo(path)
|
||||
except Exception:
|
||||
print(f"Error reading github repository info from {path}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if repo is None or repo.bare:
|
||||
self.remote = None
|
||||
else:
|
||||
try:
|
||||
self.remote = next(repo.remote().urls, None)
|
||||
self.status = 'unknown'
|
||||
except Exception:
|
||||
self.remote = None
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts
|
||||
|
||||
dirpath = os.path.join(self.path, subdir)
|
||||
if not os.path.isdir(dirpath):
|
||||
return []
|
||||
|
||||
res = []
|
||||
for filename in sorted(os.listdir(dirpath)):
|
||||
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
|
||||
|
||||
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
return res
|
||||
|
||||
def check_updates(self):
|
||||
repo = git.Repo(self.path)
|
||||
for fetch in repo.remote().fetch("--dry-run"):
|
||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||
self.can_update = True
|
||||
self.status = "behind"
|
||||
return
|
||||
|
||||
self.can_update = False
|
||||
self.status = "latest"
|
||||
|
||||
def fetch_and_reset_hard(self):
|
||||
repo = git.Repo(self.path)
|
||||
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
||||
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
||||
repo.git.fetch('--all')
|
||||
repo.git.reset('--hard', 'origin')
|
||||
|
||||
|
||||
def list_extensions():
|
||||
extensions.clear()
|
||||
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
paths = []
|
||||
for dirname in [extensions_dir, extensions_builtin_dir]:
|
||||
if not os.path.isdir(dirname):
|
||||
return
|
||||
|
||||
for extension_dirname in sorted(os.listdir(dirname)):
|
||||
path = os.path.join(dirname, extension_dirname)
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
||||
|
||||
for dirname, path, is_builtin in paths:
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
||||
extensions.append(extension)
|
||||
|
@ -1,5 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@ -7,7 +10,11 @@ from PIL import Image
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from modules import processing, shared, images, devices, sd_models
|
||||
from typing import Callable, List, OrderedDict, Tuple
|
||||
from functools import partial
|
||||
from dataclasses import dataclass
|
||||
|
||||
from modules import processing, shared, images, devices, sd_models, sd_samplers
|
||||
from modules.shared import opts
|
||||
import modules.gfpgan_model
|
||||
from modules.ui import plaintext_to_html
|
||||
@ -15,14 +22,45 @@ import modules.codeformer_model
|
||||
import piexif
|
||||
import piexif.helper
|
||||
import gradio as gr
|
||||
import safetensors.torch
|
||||
|
||||
class LruCache(OrderedDict):
|
||||
@dataclass(frozen=True)
|
||||
class Key:
|
||||
image_hash: int
|
||||
info_hash: int
|
||||
args_hash: int
|
||||
|
||||
@dataclass
|
||||
class Value:
|
||||
image: Image.Image
|
||||
info: str
|
||||
|
||||
def __init__(self, max_size: int = 5, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._max_size = max_size
|
||||
|
||||
def get(self, key: LruCache.Key) -> LruCache.Value:
|
||||
ret = super().get(key)
|
||||
if ret is not None:
|
||||
self.move_to_end(key) # Move to end of eviction list
|
||||
return ret
|
||||
|
||||
def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
|
||||
self[key] = value
|
||||
while len(self) > self._max_size:
|
||||
self.popitem(last=False)
|
||||
|
||||
|
||||
cached_images = {}
|
||||
cached_images: LruCache = LruCache(max_size=5)
|
||||
|
||||
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
|
||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
||||
devices.torch_gc()
|
||||
|
||||
shared.state.begin()
|
||||
shared.state.job = 'extras'
|
||||
|
||||
imageArr = []
|
||||
# Also keep track of original file names
|
||||
imageNameArr = []
|
||||
@ -39,9 +77,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
|
||||
if input_dir == '':
|
||||
return outputs, "Please select an input directory.", ''
|
||||
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
|
||||
image_list = shared.listfiles(input_dir)
|
||||
for img in image_list:
|
||||
try:
|
||||
image = Image.open(img)
|
||||
except Exception:
|
||||
continue
|
||||
imageArr.append(image)
|
||||
imageNameArr.append(img)
|
||||
else:
|
||||
@ -53,16 +94,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
else:
|
||||
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
||||
|
||||
# Extra operation definitions
|
||||
|
||||
for image, image_name in zip(imageArr, imageNameArr):
|
||||
if image is None:
|
||||
return outputs, "Please select an input image.", ''
|
||||
existing_pnginfo = image.info or {}
|
||||
|
||||
image = image.convert("RGB")
|
||||
info = ""
|
||||
|
||||
if gfpgan_visibility > 0:
|
||||
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
shared.state.job = 'extras-gfpgan'
|
||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
@ -70,9 +105,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
res = Image.blend(image, res, gfpgan_visibility)
|
||||
|
||||
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
|
||||
image = res
|
||||
return (res, info)
|
||||
|
||||
if codeformer_visibility > 0:
|
||||
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
shared.state.job = 'extras-codeformer'
|
||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
@ -80,53 +116,106 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
res = Image.blend(image, res, codeformer_visibility)
|
||||
|
||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
||||
image = res
|
||||
return (res, info)
|
||||
|
||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||
shared.state.job = 'extras-upscale'
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
if mode == 1 and crop:
|
||||
cropped = Image.new("RGB", (resize_w, resize_h))
|
||||
cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
|
||||
res = cropped
|
||||
return res
|
||||
|
||||
def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
|
||||
nonlocal upscaling_resize
|
||||
if resize_mode == 1:
|
||||
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
|
||||
crop_info = " (crop)" if upscaling_crop else ""
|
||||
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
|
||||
return (image, info)
|
||||
|
||||
if upscaling_resize != 1.0:
|
||||
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
|
||||
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
||||
pixels = tuple(np.array(small).flatten().tolist())
|
||||
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
|
||||
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
|
||||
@dataclass
|
||||
class UpscaleParams:
|
||||
upscaler_idx: int
|
||||
blend_alpha: float
|
||||
|
||||
c = cached_images.get(key)
|
||||
if c is None:
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
if mode == 1 and crop:
|
||||
cropped = Image.new("RGB", (resize_w, resize_h))
|
||||
cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
|
||||
c = cropped
|
||||
cached_images[key] = c
|
||||
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
|
||||
blended_result: Image.Image = None
|
||||
image_hash: str = hash(np.array(image.getdata()).tobytes())
|
||||
for upscaler in params:
|
||||
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
|
||||
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
cache_key = LruCache.Key(image_hash=image_hash,
|
||||
info_hash=hash(info),
|
||||
args_hash=hash(upscale_args))
|
||||
cached_entry = cached_images.get(cache_key)
|
||||
if cached_entry is None:
|
||||
res = upscale(image, *upscale_args)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
|
||||
cached_images.put(cache_key, LruCache.Value(image=res, info=info))
|
||||
else:
|
||||
res, info = cached_entry.image, cached_entry.info
|
||||
|
||||
return c
|
||||
if blended_result is None:
|
||||
blended_result = res
|
||||
else:
|
||||
blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
|
||||
return (blended_result, info)
|
||||
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
|
||||
res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
# Build a list of operations to run
|
||||
facefix_ops: List[Callable] = []
|
||||
facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
|
||||
facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
|
||||
|
||||
upscale_ops: List[Callable] = []
|
||||
upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
|
||||
|
||||
if upscaling_resize != 0:
|
||||
step_params: List[UpscaleParams] = []
|
||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
|
||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
||||
res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
|
||||
res = Image.blend(res, res2, extras_upscaler_2_visibility)
|
||||
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
|
||||
|
||||
image = res
|
||||
upscale_ops.append(partial(run_upscalers_blend, step_params))
|
||||
|
||||
while len(cached_images) > 2:
|
||||
del cached_images[next(iter(cached_images.keys()))]
|
||||
extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
|
||||
|
||||
images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
||||
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
|
||||
forced_filename=image_name if opts.use_original_name_batch else None)
|
||||
for image, image_name in zip(imageArr, imageNameArr):
|
||||
if image is None:
|
||||
return outputs, "Please select an input image.", ''
|
||||
|
||||
if opts.enable_pnginfo:
|
||||
shared.state.textinfo = f'Processing image {image_name}'
|
||||
|
||||
existing_pnginfo = image.info or {}
|
||||
|
||||
image = image.convert("RGB")
|
||||
info = ""
|
||||
# Run each operation on each image
|
||||
for op in extras_ops:
|
||||
image, info = op(image, info)
|
||||
|
||||
if opts.use_original_name_batch and image_name is not None:
|
||||
basename = os.path.splitext(os.path.basename(image_name))[0]
|
||||
else:
|
||||
basename = ''
|
||||
|
||||
if opts.enable_pnginfo: # append info before save
|
||||
image.info = existing_pnginfo
|
||||
image.info["extras"] = info
|
||||
|
||||
if save_output:
|
||||
# Add upscaler name as a suffix.
|
||||
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
|
||||
# Add second upscaler if applicable.
|
||||
if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
|
||||
suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
|
||||
|
||||
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
||||
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
|
||||
|
||||
if extras_mode != 2 or show_extras_results :
|
||||
outputs.append(image)
|
||||
|
||||
@ -134,30 +223,16 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
||||
|
||||
return outputs, plaintext_to_html(info), ''
|
||||
|
||||
def clear_cache():
|
||||
cached_images.clear()
|
||||
|
||||
|
||||
def run_pnginfo(image):
|
||||
if image is None:
|
||||
return '', '', ''
|
||||
|
||||
items = image.info
|
||||
geninfo = ''
|
||||
|
||||
if "exif" in image.info:
|
||||
exif = piexif.load(image.info["exif"])
|
||||
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
||||
try:
|
||||
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
||||
except ValueError:
|
||||
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
||||
|
||||
items['exif comment'] = exif_comment
|
||||
geninfo = exif_comment
|
||||
|
||||
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
||||
'loop', 'background', 'timestamp', 'duration']:
|
||||
items.pop(field, None)
|
||||
|
||||
geninfo = items.get('parameters', geninfo)
|
||||
geninfo, items = images.read_info_from_image(image)
|
||||
items = {**{'parameters': geninfo}, **items}
|
||||
|
||||
info = ''
|
||||
for key, text in items.items():
|
||||
@ -175,7 +250,10 @@ def run_pnginfo(image):
|
||||
return '', geninfo, info
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
||||
@ -187,23 +265,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
||||
|
||||
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
|
||||
teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
|
||||
|
||||
print(f"Loading {primary_model_info.filename}...")
|
||||
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
|
||||
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
|
||||
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
|
||||
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
|
||||
|
||||
if teritary_model_info is not None:
|
||||
print(f"Loading {teritary_model_info.filename}...")
|
||||
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
|
||||
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
|
||||
else:
|
||||
teritary_model = None
|
||||
theta_2 = None
|
||||
tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
|
||||
result_is_inpainting_model = False
|
||||
|
||||
theta_funcs = {
|
||||
"Weighted sum": (None, weighted_sum),
|
||||
@ -211,9 +274,19 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
||||
}
|
||||
theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||
|
||||
print(f"Merging...")
|
||||
if theta_func1 and not tertiary_model_info:
|
||||
shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
|
||||
shared.state.end()
|
||||
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
|
||||
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
|
||||
if theta_func1:
|
||||
print(f"Loading {tertiary_model_info.filename}...")
|
||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||
|
||||
for key in tqdm.tqdm(theta_1.keys()):
|
||||
if 'model' in key:
|
||||
if key in theta_2:
|
||||
@ -221,12 +294,33 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
||||
theta_1[key] = theta_func1(theta_1[key], t2)
|
||||
else:
|
||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||
del theta_2, teritary_model
|
||||
del theta_2
|
||||
|
||||
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||
print(f"Loading {primary_model_info.filename}...")
|
||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||
|
||||
print("Merging...")
|
||||
|
||||
for key in tqdm.tqdm(theta_0.keys()):
|
||||
if 'model' in key and key in theta_1:
|
||||
a = theta_0[key]
|
||||
b = theta_1[key]
|
||||
|
||||
theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
|
||||
shared.state.textinfo = f'Merging layer {key}'
|
||||
# this enables merging an inpainting model (A) with another one (B);
|
||||
# where normal model would have 4 channels, for latenst space, inpainting model would
|
||||
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
||||
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
|
||||
if a.shape[1] == 4 and b.shape[1] == 9:
|
||||
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
|
||||
|
||||
assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
||||
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||
result_is_inpainting_model = True
|
||||
else:
|
||||
theta_0[key] = theta_func2(a, b, multiplier)
|
||||
|
||||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
@ -237,17 +331,35 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
||||
theta_0[key] = theta_1[key]
|
||||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
del theta_1
|
||||
|
||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||
|
||||
filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
|
||||
filename = filename if custom_name == '' else (custom_name + '.ckpt')
|
||||
filename = \
|
||||
primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
|
||||
secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
|
||||
interp_method.replace(" ", "_") + \
|
||||
'-merged.' + \
|
||||
("inpainting." if result_is_inpainting_model else "") + \
|
||||
checkpoint_format
|
||||
|
||||
filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
|
||||
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
||||
shared.state.textinfo = f"Saving to {output_modelname}..."
|
||||
print(f"Saving to {output_modelname}...")
|
||||
torch.save(primary_model, output_modelname)
|
||||
|
||||
_, extension = os.path.splitext(output_modelname)
|
||||
if extension.lower() == ".safetensors":
|
||||
safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(theta_0, output_modelname)
|
||||
|
||||
sd_models.list_models()
|
||||
|
||||
print(f"Checkpoint saved.")
|
||||
print("Checkpoint saved.")
|
||||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||
shared.state.end()
|
||||
|
||||
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
|
@ -1,14 +1,222 @@
|
||||
import base64
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from modules.shared import script_path
|
||||
from modules import shared
|
||||
from modules import shared, ui_tempdir
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
|
||||
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
||||
type_of_gr_update = type(gr.update())
|
||||
paste_fields = {}
|
||||
bind_list = []
|
||||
|
||||
|
||||
def reset():
|
||||
paste_fields.clear()
|
||||
bind_list.clear()
|
||||
|
||||
|
||||
def quote(text):
|
||||
if ',' not in str(text):
|
||||
return text
|
||||
|
||||
text = str(text)
|
||||
text = text.replace('\\', '\\\\')
|
||||
text = text.replace('"', '\\"')
|
||||
return f'"{text}"'
|
||||
|
||||
|
||||
def image_from_url_text(filedata):
|
||||
if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
|
||||
filedata = filedata[0]
|
||||
|
||||
if type(filedata) == dict and filedata.get("is_file", False):
|
||||
filename = filedata["name"]
|
||||
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
|
||||
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
||||
|
||||
return Image.open(filename)
|
||||
|
||||
if type(filedata) == list:
|
||||
if len(filedata) == 0:
|
||||
return None
|
||||
|
||||
filedata = filedata[0]
|
||||
|
||||
if filedata.startswith("data:image/png;base64,"):
|
||||
filedata = filedata[len("data:image/png;base64,"):]
|
||||
|
||||
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
||||
image = Image.open(io.BytesIO(filedata))
|
||||
return image
|
||||
|
||||
|
||||
def add_paste_fields(tabname, init_img, fields):
|
||||
paste_fields[tabname] = {"init_img": init_img, "fields": fields}
|
||||
|
||||
# backwards compatibility for existing extensions
|
||||
import modules.ui
|
||||
if tabname == 'txt2img':
|
||||
modules.ui.txt2img_paste_fields = fields
|
||||
elif tabname == 'img2img':
|
||||
modules.ui.img2img_paste_fields = fields
|
||||
|
||||
|
||||
def integrate_settings_paste_fields(component_dict):
|
||||
from modules import ui
|
||||
|
||||
settings_map = {
|
||||
'sd_hypernetwork': 'Hypernet',
|
||||
'sd_hypernetwork_strength': 'Hypernet strength',
|
||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||
'inpainting_mask_weight': 'Conditional mask weight',
|
||||
'sd_model_checkpoint': 'Model hash',
|
||||
'eta_noise_seed_delta': 'ENSD',
|
||||
'initial_noise_multiplier': 'Noise multiplier',
|
||||
}
|
||||
settings_paste_fields = [
|
||||
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
|
||||
for k, v in settings_map.items()
|
||||
]
|
||||
|
||||
for tabname, info in paste_fields.items():
|
||||
if info["fields"] is not None:
|
||||
info["fields"] += settings_paste_fields
|
||||
|
||||
|
||||
def create_buttons(tabs_list):
|
||||
buttons = {}
|
||||
for tab in tabs_list:
|
||||
buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
|
||||
return buttons
|
||||
|
||||
|
||||
#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
|
||||
def bind_buttons(buttons, send_image, send_generate_info):
|
||||
bind_list.append([buttons, send_image, send_generate_info])
|
||||
|
||||
|
||||
def send_image_and_dimensions(x):
|
||||
if isinstance(x, Image.Image):
|
||||
img = x
|
||||
else:
|
||||
img = image_from_url_text(x)
|
||||
|
||||
if shared.opts.send_size and isinstance(img, Image.Image):
|
||||
w = img.width
|
||||
h = img.height
|
||||
else:
|
||||
w = gr.update()
|
||||
h = gr.update()
|
||||
|
||||
return img, w, h
|
||||
|
||||
|
||||
def run_bind():
|
||||
for buttons, source_image_component, send_generate_info in bind_list:
|
||||
for tab in buttons:
|
||||
button = buttons[tab]
|
||||
destination_image_component = paste_fields[tab]["init_img"]
|
||||
fields = paste_fields[tab]["fields"]
|
||||
|
||||
destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
|
||||
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
|
||||
|
||||
if source_image_component and destination_image_component:
|
||||
if isinstance(source_image_component, gr.Gallery):
|
||||
func = send_image_and_dimensions if destination_width_component else image_from_url_text
|
||||
jsfunc = "extract_image_from_gallery"
|
||||
else:
|
||||
func = send_image_and_dimensions if destination_width_component else lambda x: x
|
||||
jsfunc = None
|
||||
|
||||
button.click(
|
||||
fn=func,
|
||||
_js=jsfunc,
|
||||
inputs=[source_image_component],
|
||||
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
||||
)
|
||||
|
||||
if send_generate_info and fields is not None:
|
||||
if send_generate_info in paste_fields:
|
||||
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
|
||||
button.click(
|
||||
fn=lambda *x: x,
|
||||
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
|
||||
outputs=[field for field, name in fields if name in paste_field_names],
|
||||
)
|
||||
else:
|
||||
connect_paste(button, fields, send_generate_info)
|
||||
|
||||
button.click(
|
||||
fn=None,
|
||||
_js=f"switch_to_{tab}",
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
|
||||
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
|
||||
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
|
||||
|
||||
Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
|
||||
parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
|
||||
|
||||
If the infotext has no hash, then a hypernet with the same name will be selected instead.
|
||||
"""
|
||||
hypernet_name = hypernet_name.lower()
|
||||
if hypernet_hash is not None:
|
||||
# Try to match the hash in the name
|
||||
for hypernet_key in shared.hypernetworks.keys():
|
||||
result = re_hypernet_hash.search(hypernet_key)
|
||||
if result is not None and result[1] == hypernet_hash:
|
||||
return hypernet_key
|
||||
else:
|
||||
# Fall back to a hypernet with the same name
|
||||
for hypernet_key in shared.hypernetworks.keys():
|
||||
if hypernet_key.lower().startswith(hypernet_name):
|
||||
return hypernet_key
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def restore_old_hires_fix_params(res):
|
||||
"""for infotexts that specify old First pass size parameter, convert it into
|
||||
width, height, and hr scale"""
|
||||
|
||||
firstpass_width = res.get('First pass size-1', None)
|
||||
firstpass_height = res.get('First pass size-2', None)
|
||||
|
||||
if firstpass_width is None or firstpass_height is None:
|
||||
return
|
||||
|
||||
firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
|
||||
width = int(res.get("Size-1", 512))
|
||||
height = int(res.get("Size-2", 512))
|
||||
|
||||
if firstpass_width == 0 or firstpass_height == 0:
|
||||
# old algorithm for auto-calculating first pass size
|
||||
desired_pixel_count = 512 * 512
|
||||
actual_pixel_count = width * height
|
||||
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
|
||||
firstpass_width = math.ceil(scale * width / 64) * 64
|
||||
firstpass_height = math.ceil(scale * height / 64) * 64
|
||||
|
||||
hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
|
||||
|
||||
res['Size-1'] = firstpass_width
|
||||
res['Size-2'] = firstpass_height
|
||||
res['Hires upscale'] = hr_scale
|
||||
|
||||
|
||||
def parse_generation_parameters(x: str):
|
||||
@ -56,10 +264,24 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
else:
|
||||
res[k] = v
|
||||
|
||||
# Missing CLIP skip means it was set to 1 (the default)
|
||||
if "Clip skip" not in res:
|
||||
res["Clip skip"] = "1"
|
||||
|
||||
if "Hypernet strength" not in res:
|
||||
res["Hypernet strength"] = "1"
|
||||
|
||||
if "Hypernet" in res:
|
||||
hypernet_name = res["Hypernet"]
|
||||
hypernet_hash = res.get("Hypernet hash", None)
|
||||
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
|
||||
|
||||
restore_old_hires_fix_params(res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def connect_paste(button, paste_fields, input_comp, js=None):
|
||||
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
|
||||
def paste_func(prompt):
|
||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
||||
filename = os.path.join(script_path, "params.txt")
|
||||
@ -83,7 +305,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
|
||||
else:
|
||||
try:
|
||||
valtype = type(output.value)
|
||||
|
||||
if valtype == bool and v == "False":
|
||||
val = False
|
||||
else:
|
||||
val = valtype(v)
|
||||
|
||||
res.append(gr.update(value=val))
|
||||
except Exception:
|
||||
res.append(gr.update())
|
||||
@ -92,7 +319,9 @@ def connect_paste(button, paste_fields, input_comp, js=None):
|
||||
|
||||
button.click(
|
||||
fn=paste_func,
|
||||
_js=js,
|
||||
_js=jsfunc,
|
||||
inputs=[input_comp],
|
||||
outputs=[x[0] for x in paste_fields],
|
||||
)
|
||||
|
||||
|
||||
|
@ -36,7 +36,9 @@ def gfpgann():
|
||||
else:
|
||||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||
if hasattr(facexlib.detection.retinaface, 'device'):
|
||||
facexlib.detection.retinaface.device = devices.device_gfpgan
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
||||
loaded_gfpgan_model = model
|
||||
|
||||
return model
|
||||
|
@ -1,40 +1,72 @@
|
||||
import csv
|
||||
import datetime
|
||||
import glob
|
||||
import html
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import tqdm
|
||||
import csv
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
|
||||
from ldm.util import default
|
||||
from modules import devices, shared, processing, sd_models
|
||||
import torch
|
||||
from torch import einsum
|
||||
from einops import rearrange, repeat
|
||||
import modules.textual_inversion.dataset
|
||||
import torch
|
||||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import default
|
||||
from modules import devices, processing, sd_models, shared, sd_samplers
|
||||
from modules.textual_inversion import textual_inversion
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from statistics import stdev, mean
|
||||
|
||||
|
||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
"swish": torch.nn.Hardswish,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"sigmoid": torch.nn.Sigmoid,
|
||||
}
|
||||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
||||
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
|
||||
super().__init__()
|
||||
|
||||
assert layer_structure is not None, "layer_structure mut not be None"
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
|
||||
linears = []
|
||||
for i in range(len(layer_structure) - 1):
|
||||
|
||||
# Add a fully-connected layer
|
||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add an activation func except last layer
|
||||
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
|
||||
pass
|
||||
elif activation_func in self.activation_dict:
|
||||
linears.append(self.activation_dict[activation_func]())
|
||||
else:
|
||||
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
||||
|
||||
# Add layer normalization
|
||||
if add_layer_norm:
|
||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add dropout except last layer
|
||||
if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
|
||||
linears.append(torch.nn.Dropout(p=0.3))
|
||||
|
||||
self.linear = torch.nn.Sequential(*linears)
|
||||
|
||||
if state_dict is not None:
|
||||
@ -42,9 +74,25 @@ class HypernetworkModule(torch.nn.Module):
|
||||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
for layer in self.linear:
|
||||
layer.weight.data.normal_(mean=0.0, std=0.01)
|
||||
layer.bias.data.zero_()
|
||||
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
w, b = layer.weight.data, layer.bias.data
|
||||
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
|
||||
normal_(w, mean=0.0, std=0.01)
|
||||
normal_(b, mean=0.0, std=0)
|
||||
elif weight_init == 'XavierUniform':
|
||||
xavier_uniform_(w)
|
||||
zeros_(b)
|
||||
elif weight_init == 'XavierNormal':
|
||||
xavier_normal_(w)
|
||||
zeros_(b)
|
||||
elif weight_init == 'KaimingUniform':
|
||||
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||
zeros_(b)
|
||||
elif weight_init == 'KaimingNormal':
|
||||
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||
zeros_(b)
|
||||
else:
|
||||
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
||||
self.to(devices.device)
|
||||
|
||||
def fix_old_state_dict(self, state_dict):
|
||||
@ -69,6 +117,7 @@ class HypernetworkModule(torch.nn.Module):
|
||||
def trainables(self):
|
||||
layer_structure = []
|
||||
for layer in self.linear:
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
layer_structure += [layer.weight, layer.bias]
|
||||
return layer_structure
|
||||
|
||||
@ -81,7 +130,7 @@ class Hypernetwork:
|
||||
filename = None
|
||||
name = None
|
||||
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
|
||||
self.filename = None
|
||||
self.name = name
|
||||
self.layers = {}
|
||||
@ -89,26 +138,48 @@ class Hypernetwork:
|
||||
self.sd_checkpoint = None
|
||||
self.sd_checkpoint_name = None
|
||||
self.layer_structure = layer_structure
|
||||
self.activation_func = activation_func
|
||||
self.weight_init = weight_init
|
||||
self.add_layer_norm = add_layer_norm
|
||||
self.use_dropout = use_dropout
|
||||
self.activate_output = activate_output
|
||||
self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
|
||||
self.optimizer_name = None
|
||||
self.optimizer_state_dict = None
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
)
|
||||
self.eval_mode()
|
||||
|
||||
def weights(self):
|
||||
res = []
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
res += layer.parameters()
|
||||
return res
|
||||
|
||||
def train_mode(self):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.train()
|
||||
res += layer.trainables()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
return res
|
||||
def eval_mode(self):
|
||||
for k, layers in self.layers.items():
|
||||
for layer in layers:
|
||||
layer.eval()
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def save(self, filename):
|
||||
state_dict = {}
|
||||
optimizer_saved_dict = {}
|
||||
|
||||
for k, v in self.layers.items():
|
||||
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
||||
@ -116,11 +187,23 @@ class Hypernetwork:
|
||||
state_dict['step'] = self.step
|
||||
state_dict['name'] = self.name
|
||||
state_dict['layer_structure'] = self.layer_structure
|
||||
state_dict['activation_func'] = self.activation_func
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['weight_initialization'] = self.weight_init
|
||||
state_dict['use_dropout'] = self.use_dropout
|
||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
state_dict['activate_output'] = self.activate_output
|
||||
state_dict['last_layer_dropout'] = self.last_layer_dropout
|
||||
|
||||
if self.optimizer_name is not None:
|
||||
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
||||
|
||||
torch.save(state_dict, filename)
|
||||
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
|
||||
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
|
||||
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||
|
||||
def load(self, filename):
|
||||
self.filename = filename
|
||||
@ -130,13 +213,38 @@ class Hypernetwork:
|
||||
state_dict = torch.load(filename, map_location='cpu')
|
||||
|
||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
print(self.layer_structure)
|
||||
self.activation_func = state_dict.get('activation_func', None)
|
||||
print(f"Activation function is {self.activation_func}")
|
||||
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||
print(f"Weight initialization is {self.weight_init}")
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
print(f"Layer norm is set to {self.add_layer_norm}")
|
||||
self.use_dropout = state_dict.get('use_dropout', False)
|
||||
print(f"Dropout usage is set to {self.use_dropout}" )
|
||||
self.activate_output = state_dict.get('activate_output', True)
|
||||
print(f"Activate last layer is set to {self.activate_output}")
|
||||
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
|
||||
|
||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||
print(f"Optimizer name is {self.optimizer_name}")
|
||||
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
|
||||
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||
else:
|
||||
self.optimizer_state_dict = None
|
||||
if self.optimizer_state_dict:
|
||||
print("Loaded existing optimizer from checkpoint")
|
||||
else:
|
||||
print("No saved optimizer exists in checkpoint")
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
|
||||
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
@ -147,15 +255,18 @@ class Hypernetwork:
|
||||
|
||||
def list_hypernetworks(path):
|
||||
res = {}
|
||||
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
|
||||
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
|
||||
name = os.path.splitext(os.path.basename(filename))[0]
|
||||
res[name] = filename
|
||||
# Prevent a hypothetical "None.pt" from being listed.
|
||||
if name != "None":
|
||||
res[name + f"({sd_models.model_hash(filename)})"] = filename
|
||||
return res
|
||||
|
||||
|
||||
def load_hypernetwork(filename):
|
||||
path = shared.hypernetworks.get(filename, None)
|
||||
if path is not None:
|
||||
# Prevent any file named "None.pt" from being loaded.
|
||||
if path is not None and filename != "None":
|
||||
print(f"Loading hypernetwork {filename}")
|
||||
try:
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
@ -166,7 +277,7 @@ def load_hypernetwork(filename):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
else:
|
||||
if shared.loaded_hypernetwork is not None:
|
||||
print(f"Unloading hypernetwork")
|
||||
print("Unloading hypernetwork")
|
||||
|
||||
shared.loaded_hypernetwork = None
|
||||
|
||||
@ -240,16 +351,77 @@ def stack_conds(conds):
|
||||
return torch.stack(conds)
|
||||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
assert hypernetwork_name, 'hypernetwork not selected'
|
||||
def statistics(data):
|
||||
if len(data) < 2:
|
||||
std = 0
|
||||
else:
|
||||
std = stdev(data)
|
||||
total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
|
||||
recent_data = data[-32:]
|
||||
if len(recent_data) < 2:
|
||||
std = 0
|
||||
else:
|
||||
std = stdev(recent_data)
|
||||
recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
|
||||
return total_information, recent_information
|
||||
|
||||
|
||||
def report_statistics(loss_info:dict):
|
||||
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
|
||||
for key in keys:
|
||||
try:
|
||||
print("Loss statistics for file " + key)
|
||||
info, recent = statistics(list(loss_info[key]))
|
||||
print(info)
|
||||
print(recent)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
|
||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||
if not overwrite_old:
|
||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
|
||||
if type(layer_structure) == str:
|
||||
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
||||
|
||||
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
||||
name=name,
|
||||
enable_sizes=[int(x) for x in enable_sizes],
|
||||
layer_structure=layer_structure,
|
||||
activation_func=activation_func,
|
||||
weight_init=weight_init,
|
||||
add_layer_norm=add_layer_norm,
|
||||
use_dropout=use_dropout,
|
||||
)
|
||||
hypernet.save(fn)
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
|
||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
shared.loaded_hypernetwork = Hypernetwork()
|
||||
shared.loaded_hypernetwork.load(path)
|
||||
|
||||
shared.state.job = "train-hypernetwork"
|
||||
shared.state.textinfo = "Initializing hypernetwork training..."
|
||||
shared.state.job_count = steps
|
||||
|
||||
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
|
||||
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
|
||||
@ -267,72 +439,147 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
else:
|
||||
images_dir = None
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
|
||||
initial_step = hypernetwork.step or 0
|
||||
if initial_step >= steps:
|
||||
shared.state.textinfo = "Model has already been trained beyond specified max steps"
|
||||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
|
||||
|
||||
latent_sampling_method = ds.latent_sampling_method
|
||||
|
||||
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
if unload:
|
||||
shared.parallel_processing_allowed = False
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
hypernetwork = shared.loaded_hypernetwork
|
||||
weights = hypernetwork.weights()
|
||||
for weight in weights:
|
||||
weight.requires_grad = True
|
||||
hypernetwork.train_mode()
|
||||
|
||||
losses = torch.zeros((32,))
|
||||
# Here we use optimizer from saved HN, or we can specify as UI option.
|
||||
if hypernetwork.optimizer_name in optimizer_dict:
|
||||
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
||||
optimizer_name = hypernetwork.optimizer_name
|
||||
else:
|
||||
print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
|
||||
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
|
||||
optimizer_name = 'AdamW'
|
||||
|
||||
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
||||
try:
|
||||
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
||||
except RuntimeError as e:
|
||||
print("Cannot resume from saved optimizer!")
|
||||
print(e)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
batch_size = ds.batch_size
|
||||
gradient_step = ds.gradient_step
|
||||
# n steps = batch_size * gradient_step * n image processed
|
||||
steps_per_epoch = len(ds) // batch_size // gradient_step
|
||||
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
|
||||
loss_step = 0
|
||||
_loss_step = 0 #internal
|
||||
# size = len(ds.indexes)
|
||||
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||
# losses = torch.zeros((size,))
|
||||
# previous_mean_losses = [0]
|
||||
# previous_mean_loss = 0
|
||||
# print("Mean loss of {} elements".format(size))
|
||||
|
||||
steps_without_grad = 0
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
ititial_step = hypernetwork.step or 0
|
||||
if ititial_step > steps:
|
||||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||
for i, entries in pbar:
|
||||
hypernetwork.step = i + ititial_step
|
||||
|
||||
pbar = tqdm.tqdm(total=steps - initial_step)
|
||||
try:
|
||||
for i in range((steps-initial_step) * gradient_step):
|
||||
if scheduler.finished:
|
||||
break
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
for j, batch in enumerate(dl):
|
||||
# works as a drop_last=True for gradient accumulation
|
||||
if j == max_steps_per_epoch:
|
||||
break
|
||||
scheduler.apply(optimizer, hypernetwork.step)
|
||||
if scheduler.finished:
|
||||
break
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
||||
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||
loss = shared.sd_model(x, c)[0]
|
||||
with devices.autocast():
|
||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||
if tag_drop_out != 0 or shuffle_tags:
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
else:
|
||||
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
|
||||
loss = shared.sd_model(x, c)[0] / gradient_step
|
||||
del x
|
||||
del c
|
||||
|
||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||
_loss_step += loss.item()
|
||||
scaler.scale(loss).backward()
|
||||
# go back until we reach gradient accumulation steps
|
||||
if (j + 1) % gradient_step != 0:
|
||||
continue
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
|
||||
# scaler.unscale_(optimizer)
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
|
||||
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
hypernetwork.step += 1
|
||||
pbar.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
loss_step = _loss_step
|
||||
_loss_step = 0
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
mean_loss = losses.mean()
|
||||
if torch.isnan(mean_loss):
|
||||
raise RuntimeError("Loss diverged.")
|
||||
pbar.set_description(f"loss: {mean_loss:.7f}")
|
||||
steps_done = hypernetwork.step + 1
|
||||
|
||||
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
||||
hypernetwork.save(last_saved_file)
|
||||
epoch_num = hypernetwork.step // steps_per_epoch
|
||||
epoch_step = hypernetwork.step % steps_per_epoch
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||
"loss": f"{mean_loss:.7f}",
|
||||
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
|
||||
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
|
||||
# Before saving, change name to match current checkpoint.
|
||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
if shared.opts.save_optimizer_state:
|
||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
|
||||
"loss": f"{loss_step:.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
||||
|
||||
optimizer.zero_grad()
|
||||
if images_dir is not None and steps_done % create_image_every == 0:
|
||||
forced_filename = f'{hypernetwork_name}-{steps_done}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
hypernetwork.eval_mode()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
@ -346,47 +593,75 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_index = preview_sampler_index
|
||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = entries[0].cond_text
|
||||
p.prompt = batch.cond_text[0]
|
||||
p.steps = 20
|
||||
p.width = training_width
|
||||
p.height = training_height
|
||||
|
||||
preview_text = p.prompt
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images)>0 else None
|
||||
image = processed.images[0] if len(processed.images) > 0 else None
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
hypernetwork.train_mode()
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
image.save(last_saved_image)
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
||||
shared.state.textinfo = f"""
|
||||
<p>
|
||||
Loss: {mean_loss:.7f}<br/>
|
||||
Step: {hypernetwork.step}<br/>
|
||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Loss: {loss_step:.7f}<br/>
|
||||
Step: {steps_done}<br/>
|
||||
Last prompt: {html.escape(batch.cond_text[0])}<br/>
|
||||
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
"""
|
||||
except Exception:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
finally:
|
||||
pbar.leave = False
|
||||
pbar.close()
|
||||
hypernetwork.eval_mode()
|
||||
#report_statistics(loss_dict)
|
||||
|
||||
checkpoint = sd_models.select_checkpoint()
|
||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||
hypernetwork.optimizer_name = optimizer_name
|
||||
if shared.opts.save_optimizer_state:
|
||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||
|
||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||
hypernetwork.save(filename)
|
||||
del optimizer
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
|
||||
return hypernetwork, filename
|
||||
|
||||
|
||||
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
|
||||
old_hypernetwork_name = hypernetwork.name
|
||||
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
|
||||
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
|
||||
try:
|
||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||
hypernetwork.name = hypernetwork_name
|
||||
hypernetwork.save(filename)
|
||||
except:
|
||||
hypernetwork.sd_checkpoint = old_sd_checkpoint
|
||||
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
|
||||
hypernetwork.name = old_hypernetwork_name
|
||||
raise
|
||||
|
@ -3,31 +3,16 @@ import os
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
import modules.hypernetworks.hypernetwork
|
||||
from modules import devices, sd_hijack, shared
|
||||
|
||||
import modules.textual_inversion.textual_inversion
|
||||
import modules.textual_inversion.preprocess
|
||||
from modules import sd_hijack, shared, devices
|
||||
from modules.hypernetworks import hypernetwork
|
||||
not_available = ["hardswish", "multiheadattention"]
|
||||
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
|
||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||
|
||||
if type(layer_structure) == str:
|
||||
layer_structure = tuple(map(int, re.sub(r'\D', '', layer_structure)))
|
||||
|
||||
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
|
||||
name=name,
|
||||
enable_sizes=[int(x) for x in enable_sizes],
|
||||
layer_structure=layer_structure,
|
||||
add_layer_norm=add_layer_norm,
|
||||
)
|
||||
hypernet.save(fn)
|
||||
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
|
||||
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
|
||||
|
||||
|
||||
def train_hypernetwork(*args):
|
||||
|
@ -1,4 +1,8 @@
|
||||
import datetime
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import pytz
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
@ -11,8 +15,9 @@ import piexif.helper
|
||||
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
||||
from fonts.ttf import Roboto
|
||||
import string
|
||||
import json
|
||||
|
||||
from modules import sd_samplers, shared
|
||||
from modules import sd_samplers, shared, script_callbacks
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
@ -34,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
|
||||
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
|
||||
w, h = imgs[0].size
|
||||
grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
|
||||
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
|
||||
script_callbacks.image_grid_callback(params)
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||
w, h = imgs[0].size
|
||||
grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
|
||||
|
||||
for i, img in enumerate(params.imgs):
|
||||
grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
|
||||
|
||||
return grid
|
||||
|
||||
@ -131,8 +139,19 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
||||
lines.append(word)
|
||||
return lines
|
||||
|
||||
def draw_texts(drawing, draw_x, draw_y, lines):
|
||||
def get_font(fontsize):
|
||||
try:
|
||||
return ImageFont.truetype(opts.font or Roboto, fontsize)
|
||||
except Exception:
|
||||
return ImageFont.truetype(Roboto, fontsize)
|
||||
|
||||
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
||||
for i, line in enumerate(lines):
|
||||
fnt = initial_fnt
|
||||
fontsize = initial_fontsize
|
||||
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
||||
fontsize -= 1
|
||||
fnt = get_font(fontsize)
|
||||
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
||||
|
||||
if not line.is_active:
|
||||
@ -143,10 +162,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
||||
fontsize = (width + height) // 25
|
||||
line_spacing = fontsize // 2
|
||||
|
||||
try:
|
||||
fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
|
||||
except Exception:
|
||||
fnt = ImageFont.truetype(Roboto, fontsize)
|
||||
fnt = get_font(fontsize)
|
||||
|
||||
color_active = (0, 0, 0)
|
||||
color_inactive = (153, 153, 153)
|
||||
@ -173,6 +189,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
||||
for line in texts:
|
||||
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
|
||||
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
|
||||
line.allowed_width = allowed_width
|
||||
|
||||
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
|
||||
@ -189,13 +206,13 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
||||
x = pad_left + width * col + width / 2
|
||||
y = pad_top / 2 - hor_text_heights[col] / 2
|
||||
|
||||
draw_texts(d, x, y, hor_texts[col])
|
||||
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
|
||||
|
||||
for row in range(rows):
|
||||
x = pad_left / 2
|
||||
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
|
||||
|
||||
draw_texts(d, x, y, ver_texts[row])
|
||||
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
|
||||
|
||||
return result
|
||||
|
||||
@ -213,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
||||
|
||||
|
||||
def resize_image(resize_mode, im, width, height):
|
||||
def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
"""
|
||||
Resizes an image with the specified resize_mode, width, and height.
|
||||
|
||||
Args:
|
||||
resize_mode: The mode to use when resizing the image.
|
||||
0: Resize the image to the specified width and height.
|
||||
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
||||
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
|
||||
im: The image to resize.
|
||||
width: The width to resize the image to.
|
||||
height: The height to resize the image to.
|
||||
upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
|
||||
"""
|
||||
|
||||
upscaler_name = upscaler_name or opts.upscaler_for_img2img
|
||||
|
||||
def resize(im, w, h):
|
||||
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
|
||||
if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
|
||||
return im.resize((w, h), resample=LANCZOS)
|
||||
|
||||
scale = max(w / im.width, h / im.height)
|
||||
|
||||
if scale > 1.0:
|
||||
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
|
||||
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
|
||||
upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
|
||||
assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
|
||||
|
||||
upscaler = upscalers[0]
|
||||
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
||||
@ -273,10 +306,15 @@ invalid_filename_chars = '<>:"/\\|?*\n'
|
||||
invalid_filename_prefix = ' '
|
||||
invalid_filename_postfix = ' .'
|
||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
||||
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
||||
max_filename_part_length = 128
|
||||
|
||||
|
||||
def sanitize_filename_part(text, replace_spaces=True):
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
if replace_spaces:
|
||||
text = text.replace(' ', '_')
|
||||
|
||||
@ -286,49 +324,105 @@ def sanitize_filename_part(text, replace_spaces=True):
|
||||
return text
|
||||
|
||||
|
||||
def apply_filename_pattern(x, p, seed, prompt):
|
||||
max_prompt_words = opts.directories_max_prompt_words
|
||||
class FilenameGenerator:
|
||||
replacements = {
|
||||
'seed': lambda self: self.seed if self.seed is not None else '',
|
||||
'steps': lambda self: self.p and self.p.steps,
|
||||
'cfg': lambda self: self.p and self.p.cfg_scale,
|
||||
'width': lambda self: self.image.width,
|
||||
'height': lambda self: self.image.height,
|
||||
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
|
||||
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
|
||||
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
|
||||
'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
|
||||
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
||||
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
||||
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||
'prompt_words': lambda self: self.prompt_words(),
|
||||
}
|
||||
default_time_format = '%Y%m%d%H%M%S'
|
||||
|
||||
if seed is not None:
|
||||
x = x.replace("[seed]", str(seed))
|
||||
def __init__(self, p, seed, prompt, image):
|
||||
self.p = p
|
||||
self.seed = seed
|
||||
self.prompt = prompt
|
||||
self.image = image
|
||||
|
||||
if p is not None:
|
||||
x = x.replace("[steps]", str(p.steps))
|
||||
x = x.replace("[cfg]", str(p.cfg_scale))
|
||||
x = x.replace("[width]", str(p.width))
|
||||
x = x.replace("[height]", str(p.height))
|
||||
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
|
||||
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
||||
def prompt_no_style(self):
|
||||
if self.p is None or self.prompt is None:
|
||||
return None
|
||||
|
||||
x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
|
||||
x = x.replace("[date]", datetime.date.today().isoformat())
|
||||
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
|
||||
x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
|
||||
|
||||
# Apply [prompt] at last. Because it may contain any replacement word.^M
|
||||
if prompt is not None:
|
||||
x = x.replace("[prompt]", sanitize_filename_part(prompt))
|
||||
if "[prompt_no_styles]" in x:
|
||||
prompt_no_style = prompt
|
||||
for style in shared.prompt_styles.get_style_prompts(p.styles):
|
||||
prompt_no_style = self.prompt
|
||||
for style in shared.prompt_styles.get_style_prompts(self.p.styles):
|
||||
if len(style) > 0:
|
||||
style_parts = [y for y in style.split("{prompt}")]
|
||||
for part in style_parts:
|
||||
for part in style.split("{prompt}"):
|
||||
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
|
||||
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
|
||||
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
|
||||
|
||||
x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
|
||||
if "[prompt_words]" in x:
|
||||
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
|
||||
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
|
||||
|
||||
return sanitize_filename_part(prompt_no_style, replace_spaces=False)
|
||||
|
||||
def prompt_words(self):
|
||||
words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
|
||||
if len(words) == 0:
|
||||
words = ["empty"]
|
||||
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
||||
return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
|
||||
|
||||
if cmd_opts.hide_ui_dir_config:
|
||||
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
|
||||
def datetime(self, *args):
|
||||
time_datetime = datetime.datetime.now()
|
||||
|
||||
return x
|
||||
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
|
||||
try:
|
||||
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
||||
except pytz.exceptions.UnknownTimeZoneError as _:
|
||||
time_zone = None
|
||||
|
||||
time_zone_time = time_datetime.astimezone(time_zone)
|
||||
try:
|
||||
formatted_time = time_zone_time.strftime(time_format)
|
||||
except (ValueError, TypeError) as _:
|
||||
formatted_time = time_zone_time.strftime(self.default_time_format)
|
||||
|
||||
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
||||
|
||||
def apply(self, x):
|
||||
res = ''
|
||||
|
||||
for m in re_pattern.finditer(x):
|
||||
text, pattern = m.groups()
|
||||
res += text
|
||||
|
||||
if pattern is None:
|
||||
continue
|
||||
|
||||
pattern_args = []
|
||||
while True:
|
||||
m = re_pattern_arg.match(pattern)
|
||||
if m is None:
|
||||
break
|
||||
|
||||
pattern, arg = m.groups()
|
||||
pattern_args.insert(0, arg)
|
||||
|
||||
fun = self.replacements.get(pattern.lower())
|
||||
if fun is not None:
|
||||
try:
|
||||
replacement = fun(self, *pattern_args)
|
||||
except Exception:
|
||||
replacement = None
|
||||
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if replacement is not None:
|
||||
res += str(replacement)
|
||||
continue
|
||||
|
||||
res += f'[{pattern}]'
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def get_next_sequence_number(path, basename):
|
||||
@ -354,7 +448,7 @@ def get_next_sequence_number(path, basename):
|
||||
|
||||
|
||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
||||
'''Save an image.
|
||||
"""Save an image.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image`):
|
||||
@ -385,66 +479,94 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
The full path of the saved imaged.
|
||||
txt_fullfn (`str` or None):
|
||||
If a text file is saved for this image, this will be its full path. Otherwise None.
|
||||
'''
|
||||
if short_filename or prompt is None or seed is None:
|
||||
"""
|
||||
namegen = FilenameGenerator(p, seed, prompt, image)
|
||||
|
||||
if save_to_dirs is None:
|
||||
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
||||
|
||||
if save_to_dirs:
|
||||
dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
|
||||
path = os.path.join(path, dirname)
|
||||
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
if forced_filename is None:
|
||||
if short_filename or seed is None:
|
||||
file_decoration = ""
|
||||
elif opts.save_to_dirs:
|
||||
file_decoration = opts.samples_filename_pattern or "[seed]"
|
||||
else:
|
||||
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
|
||||
|
||||
if file_decoration != "":
|
||||
file_decoration = "-" + file_decoration.lower()
|
||||
add_number = opts.save_images_add_number or file_decoration == ''
|
||||
|
||||
file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
|
||||
if file_decoration != "" and add_number:
|
||||
file_decoration = "-" + file_decoration
|
||||
|
||||
if extension == 'png' and opts.enable_pnginfo and info is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
file_decoration = namegen.apply(file_decoration) + suffix
|
||||
|
||||
if existing_info is not None:
|
||||
for k, v in existing_info.items():
|
||||
pnginfo.add_text(k, str(v))
|
||||
|
||||
pnginfo.add_text(pnginfo_section_name, info)
|
||||
else:
|
||||
pnginfo = None
|
||||
|
||||
if save_to_dirs is None:
|
||||
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
||||
|
||||
if save_to_dirs:
|
||||
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
|
||||
path = os.path.join(path, dirname)
|
||||
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
if forced_filename is None:
|
||||
if add_number:
|
||||
basecount = get_next_sequence_number(path, basename)
|
||||
fullfn = "a.png"
|
||||
fullfn_without_extension = "a"
|
||||
fullfn = None
|
||||
for i in range(500):
|
||||
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
|
||||
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
||||
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
|
||||
if not os.path.exists(fullfn):
|
||||
break
|
||||
else:
|
||||
fullfn = os.path.join(path, f"{file_decoration}.{extension}")
|
||||
else:
|
||||
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
|
||||
fullfn_without_extension = os.path.join(path, forced_filename)
|
||||
|
||||
def exif_bytes():
|
||||
return piexif.dump({
|
||||
pnginfo = existing_info or {}
|
||||
if info is not None:
|
||||
pnginfo[pnginfo_section_name] = info
|
||||
|
||||
params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
|
||||
script_callbacks.before_image_saved_callback(params)
|
||||
|
||||
image = params.image
|
||||
fullfn = params.filename
|
||||
info = params.pnginfo.get(pnginfo_section_name, None)
|
||||
|
||||
def _atomically_save_image(image_to_save, filename_without_extension, extension):
|
||||
# save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
||||
temp_file_path = filename_without_extension + ".tmp"
|
||||
image_format = Image.registered_extensions()[extension]
|
||||
|
||||
if extension.lower() == '.png':
|
||||
pnginfo_data = PngImagePlugin.PngInfo()
|
||||
if opts.enable_pnginfo:
|
||||
for k, v in params.pnginfo.items():
|
||||
pnginfo_data.add_text(k, str(v))
|
||||
|
||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||
|
||||
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
||||
if image_to_save.mode == 'RGBA':
|
||||
image_to_save = image_to_save.convert("RGB")
|
||||
|
||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
|
||||
|
||||
if opts.enable_pnginfo and info is not None:
|
||||
exif_bytes = piexif.dump({
|
||||
"Exif": {
|
||||
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
|
||||
},
|
||||
})
|
||||
|
||||
if extension.lower() in ("jpg", "jpeg", "webp"):
|
||||
image.save(fullfn, quality=opts.jpeg_quality)
|
||||
if opts.enable_pnginfo and info is not None:
|
||||
piexif.insert(exif_bytes(), fullfn)
|
||||
piexif.insert(exif_bytes, temp_file_path)
|
||||
else:
|
||||
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
|
||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
|
||||
|
||||
# atomically rename the file with correct extension
|
||||
os.replace(temp_file_path, filename_without_extension + extension)
|
||||
|
||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||
_atomically_save_image(image, fullfn_without_extension, extension)
|
||||
|
||||
image.already_saved_as = fullfn
|
||||
|
||||
target_side_length = 4000
|
||||
oversize = image.width > target_side_length or image.height > target_side_length
|
||||
@ -456,9 +578,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
elif oversize:
|
||||
image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
|
||||
|
||||
image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
|
||||
if opts.enable_pnginfo and info is not None:
|
||||
piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
|
||||
_atomically_save_image(image, fullfn_without_extension, ".jpg")
|
||||
|
||||
if opts.save_txt and info is not None:
|
||||
txt_fullfn = f"{fullfn_without_extension}.txt"
|
||||
@ -467,13 +587,50 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
else:
|
||||
txt_fullfn = None
|
||||
|
||||
script_callbacks.image_saved_callback(params)
|
||||
|
||||
return fullfn, txt_fullfn
|
||||
|
||||
|
||||
def read_info_from_image(image):
|
||||
items = image.info or {}
|
||||
|
||||
geninfo = items.pop('parameters', None)
|
||||
|
||||
if "exif" in items:
|
||||
exif = piexif.load(items["exif"])
|
||||
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
||||
try:
|
||||
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
||||
except ValueError:
|
||||
exif_comment = exif_comment.decode('utf8', errors="ignore")
|
||||
|
||||
items['exif comment'] = exif_comment
|
||||
geninfo = exif_comment
|
||||
|
||||
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
|
||||
'loop', 'background', 'timestamp', 'duration']:
|
||||
items.pop(field, None)
|
||||
|
||||
if items.get("Software", None) == "NovelAI":
|
||||
try:
|
||||
json_info = json.loads(items["Comment"])
|
||||
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
|
||||
|
||||
geninfo = f"""{items["Description"]}
|
||||
Negative prompt: {json_info["uc"]}
|
||||
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
|
||||
except Exception:
|
||||
print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
return geninfo, items
|
||||
|
||||
|
||||
def image_data(data):
|
||||
try:
|
||||
image = Image.open(io.BytesIO(data))
|
||||
textinfo = image.text["parameters"]
|
||||
textinfo, _ = read_info_from_image(image)
|
||||
return textinfo, None
|
||||
except Exception:
|
||||
pass
|
||||
@ -487,3 +644,14 @@ def image_data(data):
|
||||
pass
|
||||
|
||||
return '', None
|
||||
|
||||
|
||||
def flatten(img, bgcolor):
|
||||
"""replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
|
||||
|
||||
if img.mode == "RGBA":
|
||||
background = Image.new('RGBA', img.size, bgcolor)
|
||||
background.paste(img, mask=img)
|
||||
img = background
|
||||
|
||||
return img.convert('RGB')
|
||||
|
@ -1,183 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
def traverse_all_files(output_dir, image_list, curr_dir=None):
|
||||
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
|
||||
try:
|
||||
f_list = os.listdir(curr_path)
|
||||
except:
|
||||
if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
|
||||
image_list.append(curr_dir)
|
||||
return image_list
|
||||
for file in f_list:
|
||||
file = file if curr_dir is None else os.path.join(curr_dir, file)
|
||||
file_path = os.path.join(curr_path, file)
|
||||
if file[-4:] == ".txt":
|
||||
pass
|
||||
elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
|
||||
image_list.append(file)
|
||||
else:
|
||||
image_list = traverse_all_files(output_dir, image_list, file)
|
||||
return image_list
|
||||
|
||||
|
||||
def get_recent_images(dir_name, page_index, step, image_index, tabname):
|
||||
page_index = int(page_index)
|
||||
image_list = []
|
||||
if not os.path.exists(dir_name):
|
||||
pass
|
||||
elif os.path.isdir(dir_name):
|
||||
image_list = traverse_all_files(dir_name, image_list)
|
||||
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
|
||||
else:
|
||||
print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
|
||||
num = 48 if tabname != "extras" else 12
|
||||
max_page_index = len(image_list) // num + 1
|
||||
page_index = max_page_index if page_index == -1 else page_index + step
|
||||
page_index = 1 if page_index < 1 else page_index
|
||||
page_index = max_page_index if page_index > max_page_index else page_index
|
||||
idx_frm = (page_index - 1) * num
|
||||
image_list = image_list[idx_frm:idx_frm + num]
|
||||
image_index = int(image_index)
|
||||
if image_index < 0 or image_index > len(image_list) - 1:
|
||||
current_file = None
|
||||
hidden = None
|
||||
else:
|
||||
current_file = image_list[int(image_index)]
|
||||
hidden = os.path.join(dir_name, current_file)
|
||||
return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
|
||||
|
||||
|
||||
def first_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, 1, 0, image_index, tabname)
|
||||
|
||||
|
||||
def end_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, -1, 0, image_index, tabname)
|
||||
|
||||
|
||||
def prev_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, -1, image_index, tabname)
|
||||
|
||||
|
||||
def next_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, 1, image_index, tabname)
|
||||
|
||||
|
||||
def page_index_change(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, 0, image_index, tabname)
|
||||
|
||||
|
||||
def show_image_info(num, image_path, filenames):
|
||||
# print(f"select image {num}")
|
||||
file = filenames[int(num)]
|
||||
return file, num, os.path.join(image_path, file)
|
||||
|
||||
|
||||
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
|
||||
if name == "":
|
||||
return filenames, delete_num
|
||||
else:
|
||||
delete_num = int(delete_num)
|
||||
index = list(filenames).index(name)
|
||||
i = 0
|
||||
new_file_list = []
|
||||
for name in filenames:
|
||||
if i >= index and i < index + delete_num:
|
||||
path = os.path.join(dir_name, name)
|
||||
if os.path.exists(path):
|
||||
print(f"Delete file {path}")
|
||||
os.remove(path)
|
||||
txt_file = os.path.splitext(path)[0] + ".txt"
|
||||
if os.path.exists(txt_file):
|
||||
os.remove(txt_file)
|
||||
else:
|
||||
print(f"Not exists file {path}")
|
||||
else:
|
||||
new_file_list.append(name)
|
||||
i += 1
|
||||
return new_file_list, 1
|
||||
|
||||
|
||||
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
||||
if opts.outdir_samples != "":
|
||||
dir_name = opts.outdir_samples
|
||||
elif tabname == "txt2img":
|
||||
dir_name = opts.outdir_txt2img_samples
|
||||
elif tabname == "img2img":
|
||||
dir_name = opts.outdir_img2img_samples
|
||||
elif tabname == "extras":
|
||||
dir_name = opts.outdir_extras_samples
|
||||
else:
|
||||
return
|
||||
with gr.Row():
|
||||
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
|
||||
first_page = gr.Button('First Page')
|
||||
prev_page = gr.Button('Prev Page')
|
||||
page_index = gr.Number(value=1, label="Page Index")
|
||||
next_page = gr.Button('Next Page')
|
||||
end_page = gr.Button('End Page')
|
||||
with gr.Row(elem_id=tabname + "_images_history"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
|
||||
with gr.Row():
|
||||
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
|
||||
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
|
||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
|
||||
img_file_name = gr.Textbox(label="File Name", interactive=False)
|
||||
with gr.Row():
|
||||
# hiden items
|
||||
|
||||
img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
|
||||
tabname_box = gr.Textbox(tabname, visible=False)
|
||||
image_index = gr.Textbox(value=-1, visible=False)
|
||||
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
|
||||
filenames = gr.State()
|
||||
hidden = gr.Image(type="pil", visible=False)
|
||||
info1 = gr.Textbox(visible=False)
|
||||
info2 = gr.Textbox(visible=False)
|
||||
|
||||
# turn pages
|
||||
gallery_inputs = [img_path, page_index, image_index, tabname_box]
|
||||
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
|
||||
|
||||
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
|
||||
|
||||
# other funcitons
|
||||
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
|
||||
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
|
||||
delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
|
||||
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
||||
|
||||
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
||||
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
|
||||
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
|
||||
|
||||
|
||||
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history:
|
||||
with gr.Tabs() as tabs:
|
||||
with gr.Tab("txt2img history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
|
||||
show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
|
||||
with gr.Tab("img2img history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
||||
show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
|
||||
with gr.Tab("extras history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
||||
show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
|
||||
return images_history
|
@ -4,9 +4,9 @@ import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps, ImageChops
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
||||
|
||||
from modules import devices
|
||||
from modules import devices, sd_samplers
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
@ -19,7 +19,7 @@ import modules.scripts
|
||||
def process_batch(p, input_dir, output_dir, args):
|
||||
processing.fix_seed(p)
|
||||
|
||||
images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
|
||||
images = shared.listfiles(input_dir)
|
||||
|
||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
|
||||
@ -39,6 +39,8 @@ def process_batch(p, input_dir, output_dir, args):
|
||||
break
|
||||
|
||||
img = Image.open(image)
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
img = ImageOps.exif_transpose(img)
|
||||
p.init_images = [img] * p.batch_size
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
@ -53,27 +55,48 @@ def process_batch(p, input_dir, output_dir, args):
|
||||
filename = f"{left}-{n}{right}"
|
||||
|
||||
if not save_normally:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
processed_image.save(os.path.join(output_dir, filename))
|
||||
|
||||
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
|
||||
is_inpaint = mode == 1
|
||||
is_batch = mode == 2
|
||||
|
||||
if is_inpaint:
|
||||
# Drawn mask
|
||||
if mask_mode == 0:
|
||||
image = init_img_with_mask['image']
|
||||
mask = init_img_with_mask['mask']
|
||||
is_mask_sketch = isinstance(init_img_with_mask, dict)
|
||||
is_mask_paint = not is_mask_sketch
|
||||
if is_mask_sketch:
|
||||
# Sketch: mask iff. not transparent
|
||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
|
||||
image = image.convert('RGB')
|
||||
else:
|
||||
# Color-sketch: mask iff. painted over
|
||||
image = init_img_with_mask
|
||||
orig = init_img_with_mask_orig or init_img_with_mask
|
||||
pred = np.any(np.array(image) != np.array(orig), axis=-1)
|
||||
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
|
||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||
|
||||
image = image.convert("RGB")
|
||||
# Uploaded mask
|
||||
else:
|
||||
image = init_img_inpaint
|
||||
mask = init_mask_inpaint
|
||||
# No mask
|
||||
else:
|
||||
image = init_img
|
||||
mask = None
|
||||
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
if image is not None:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
|
||||
p = StableDiffusionProcessingImg2Img(
|
||||
@ -89,7 +112,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
||||
seed_resize_from_h=seed_resize_from_h,
|
||||
seed_resize_from_w=seed_resize_from_w,
|
||||
seed_enable_extras=seed_enable_extras,
|
||||
sampler_index=sampler_index,
|
||||
sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=steps,
|
||||
@ -109,6 +132,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
||||
inpainting_mask_invert=inpainting_mask_invert,
|
||||
)
|
||||
|
||||
p.scripts = modules.scripts.scripts_txt2img
|
||||
p.script_args = args
|
||||
|
||||
if shared.cmd_opts.enable_console_prompts:
|
||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
@ -125,6 +151,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
||||
if processed is None:
|
||||
processed = process_images(p)
|
||||
|
||||
p.close()
|
||||
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
generation_info_js = processed.js()
|
||||
@ -134,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
||||
if opts.do_not_show_images:
|
||||
processed.images = []
|
||||
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
||||
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
|
||||
|
5
modules/import_hook.py
Normal file
5
modules/import_hook.py
Normal file
@ -0,0 +1,5 @@
|
||||
import sys
|
||||
|
||||
# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
|
||||
if "--xformers" not in "".join(sys.argv):
|
||||
sys.modules["xformers"] = None
|
@ -1,4 +1,3 @@
|
||||
import contextlib
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
@ -11,10 +10,9 @@ from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import devices, paths, lowvram
|
||||
from modules import devices, paths, lowvram, modelloader
|
||||
|
||||
blip_image_eval_size = 384
|
||||
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
||||
clip_model_name = 'ViT-L/14'
|
||||
|
||||
Category = namedtuple("Category", ["name", "topn", "items"])
|
||||
@ -28,9 +26,11 @@ class InterrogateModels:
|
||||
clip_preprocess = None
|
||||
categories = None
|
||||
dtype = None
|
||||
running_on_cpu = None
|
||||
|
||||
def __init__(self, content_dir):
|
||||
self.categories = []
|
||||
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
||||
|
||||
if os.path.exists(content_dir):
|
||||
for filename in os.listdir(content_dir):
|
||||
@ -45,7 +45,14 @@ class InterrogateModels:
|
||||
def load_blip_model(self):
|
||||
import models.blip
|
||||
|
||||
blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
||||
files = modelloader.load_models(
|
||||
model_path=os.path.join(paths.models_path, "BLIP"),
|
||||
model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
|
||||
ext_filter=[".pth"],
|
||||
download_name='model_base_caption_capfilt_large.pth',
|
||||
)
|
||||
|
||||
blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
||||
blip_model.eval()
|
||||
|
||||
return blip_model
|
||||
@ -53,7 +60,11 @@ class InterrogateModels:
|
||||
def load_clip_model(self):
|
||||
import clip
|
||||
|
||||
model, preprocess = clip.load(clip_model_name)
|
||||
if self.running_on_cpu:
|
||||
model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
|
||||
else:
|
||||
model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
|
||||
|
||||
model.eval()
|
||||
model = model.to(devices.device_interrogate)
|
||||
|
||||
@ -62,14 +73,14 @@ class InterrogateModels:
|
||||
def load(self):
|
||||
if self.blip_model is None:
|
||||
self.blip_model = self.load_blip_model()
|
||||
if not shared.cmd_opts.no_half:
|
||||
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
||||
self.blip_model = self.blip_model.half()
|
||||
|
||||
self.blip_model = self.blip_model.to(devices.device_interrogate)
|
||||
|
||||
if self.clip_model is None:
|
||||
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
||||
if not shared.cmd_opts.no_half:
|
||||
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
||||
self.clip_model = self.clip_model.half()
|
||||
|
||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||
@ -124,8 +135,9 @@ class InterrogateModels:
|
||||
return caption[0]
|
||||
|
||||
def interrogate(self, pil_image):
|
||||
res = None
|
||||
|
||||
res = ""
|
||||
shared.state.begin()
|
||||
shared.state.job = 'interrogate'
|
||||
try:
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
@ -142,8 +154,7 @@ class InterrogateModels:
|
||||
|
||||
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
||||
|
||||
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
with torch.no_grad(), precision_scope("cuda"):
|
||||
with torch.no_grad(), devices.autocast():
|
||||
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
@ -162,10 +173,11 @@ class InterrogateModels:
|
||||
res += ", " + match
|
||||
|
||||
except Exception:
|
||||
print(f"Error interrogating", file=sys.stderr)
|
||||
print("Error interrogating", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
res += "<error>"
|
||||
|
||||
self.unload()
|
||||
shared.state.end()
|
||||
|
||||
return res
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
|
||||
localizations = {}
|
||||
|
||||
|
||||
@ -16,6 +17,11 @@ def list_localizations(dirname):
|
||||
|
||||
localizations[fn] = os.path.join(dirname, file)
|
||||
|
||||
from modules import scripts
|
||||
for file in scripts.list_scripts("localizations", ".json"):
|
||||
fn, ext = os.path.splitext(file.filename)
|
||||
localizations[fn] = file.path
|
||||
|
||||
|
||||
def localization_js(current_localization_name):
|
||||
fn = localizations.get(current_localization_name, None)
|
||||
|
@ -1,9 +1,8 @@
|
||||
import torch
|
||||
from modules.devices import get_optimal_device
|
||||
from modules import devices
|
||||
|
||||
module_in_gpu = None
|
||||
cpu = torch.device("cpu")
|
||||
device = gpu = get_optimal_device()
|
||||
|
||||
|
||||
def send_everything_to_cpu():
|
||||
@ -33,34 +32,49 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
if module_in_gpu is not None:
|
||||
module_in_gpu.to(cpu)
|
||||
|
||||
module.to(gpu)
|
||||
module.to(devices.device)
|
||||
module_in_gpu = module
|
||||
|
||||
# see below for register_forward_pre_hook;
|
||||
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
|
||||
# useless here, and we just replace those methods
|
||||
def first_stage_model_encode_wrap(self, encoder, x):
|
||||
send_me_to_gpu(self, None)
|
||||
return encoder(x)
|
||||
|
||||
def first_stage_model_decode_wrap(self, decoder, z):
|
||||
send_me_to_gpu(self, None)
|
||||
return decoder(z)
|
||||
first_stage_model = sd_model.first_stage_model
|
||||
first_stage_model_encode = sd_model.first_stage_model.encode
|
||||
first_stage_model_decode = sd_model.first_stage_model.decode
|
||||
|
||||
# remove three big modules, cond, first_stage, and unet from the model and then
|
||||
def first_stage_model_encode_wrap(x):
|
||||
send_me_to_gpu(first_stage_model, None)
|
||||
return first_stage_model_encode(x)
|
||||
|
||||
def first_stage_model_decode_wrap(z):
|
||||
send_me_to_gpu(first_stage_model, None)
|
||||
return first_stage_model_decode(z)
|
||||
|
||||
# for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
|
||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
|
||||
|
||||
# remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
|
||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
|
||||
sd_model.to(device)
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
|
||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
|
||||
sd_model.to(devices.device)
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
|
||||
|
||||
# register hooks for those the first two models
|
||||
# register hooks for those the first three models
|
||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
|
||||
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
|
||||
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
||||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||
if sd_model.depth_model:
|
||||
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||
|
||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||
sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
|
||||
del sd_model.cond_stage_model.transformer
|
||||
|
||||
if use_medvram:
|
||||
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||
else:
|
||||
@ -70,7 +84,7 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
# so that only one of them is in GPU at a time
|
||||
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
||||
sd_model.model.to(device)
|
||||
sd_model.model.to(devices.device)
|
||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
||||
|
||||
# install hooks for bits of third model
|
||||
|
@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w
|
||||
ratio_processing = processing_width / processing_height
|
||||
|
||||
if ratio_crop_region > ratio_processing:
|
||||
desired_height = (x2 - x1) * ratio_processing
|
||||
desired_height = (x2 - x1) / ratio_processing
|
||||
desired_height_diff = int(desired_height - (y2-y1))
|
||||
y1 -= desired_height_diff//2
|
||||
y2 += desired_height_diff - desired_height_diff//2
|
||||
|
@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
|
||||
def read(self):
|
||||
if not self.disabled:
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
self.data["free"] = free
|
||||
self.data["total"] = total
|
||||
|
||||
torch_stats = torch.cuda.memory_stats(self.device)
|
||||
self.data["active"] = torch_stats["active.all.current"]
|
||||
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
|
||||
self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
|
||||
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
|
||||
self.data["system_peak"] = total - self.data["min_free"]
|
||||
|
||||
|
@ -82,9 +82,13 @@ def cleanup_models():
|
||||
src_path = models_path
|
||||
dest_path = os.path.join(models_path, "Stable-diffusion")
|
||||
move_files(src_path, dest_path, ".ckpt")
|
||||
move_files(src_path, dest_path, ".safetensors")
|
||||
src_path = os.path.join(root_path, "ESRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(models_path, "BSRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path, ".pth")
|
||||
src_path = os.path.join(root_path, "gfpgan")
|
||||
dest_path = os.path.join(models_path, "GFPGAN")
|
||||
move_files(src_path, dest_path)
|
||||
@ -119,11 +123,27 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||
pass
|
||||
|
||||
|
||||
builtin_upscaler_classes = []
|
||||
forbidden_upscaler_classes = set()
|
||||
|
||||
|
||||
def list_builtin_upscalers():
|
||||
load_upscalers()
|
||||
|
||||
builtin_upscaler_classes.clear()
|
||||
builtin_upscaler_classes.extend(Upscaler.__subclasses__())
|
||||
|
||||
|
||||
def forbid_loaded_nonbuiltin_upscalers():
|
||||
for cls in Upscaler.__subclasses__():
|
||||
if cls not in builtin_upscaler_classes:
|
||||
forbidden_upscaler_classes.add(cls)
|
||||
|
||||
|
||||
def load_upscalers():
|
||||
sd = shared.script_path
|
||||
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
||||
# so we'll try to import any _model.py files before looking in __subclasses__
|
||||
modules_dir = os.path.join(sd, "modules")
|
||||
modules_dir = os.path.join(shared.script_path, "modules")
|
||||
for file in os.listdir(modules_dir):
|
||||
if "_model.py" in file:
|
||||
model_name = file.replace("_model.py", "")
|
||||
@ -132,22 +152,16 @@ def load_upscalers():
|
||||
importlib.import_module(full_model)
|
||||
except:
|
||||
pass
|
||||
|
||||
datas = []
|
||||
c_o = vars(shared.cmd_opts)
|
||||
commandline_options = vars(shared.cmd_opts)
|
||||
for cls in Upscaler.__subclasses__():
|
||||
if cls in forbidden_upscaler_classes:
|
||||
continue
|
||||
|
||||
name = cls.__name__
|
||||
module_name = cls.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
class_ = getattr(module, name)
|
||||
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||
opt_string = None
|
||||
try:
|
||||
if cmd_name in c_o:
|
||||
opt_string = c_o[cmd_name]
|
||||
except:
|
||||
pass
|
||||
scaler = class_(opt_string)
|
||||
for child in scaler.scalers:
|
||||
datas.append(child)
|
||||
scaler = cls(commandline_options.get(cmd_name, None))
|
||||
datas += scaler.scalers
|
||||
|
||||
shared.sd_upscalers = datas
|
||||
|
@ -1,14 +1,23 @@
|
||||
from pyngrok import ngrok, conf, exception
|
||||
|
||||
|
||||
def connect(token, port, region):
|
||||
if token == None:
|
||||
account = None
|
||||
if token is None:
|
||||
token = 'None'
|
||||
else:
|
||||
if ':' in token:
|
||||
# token = authtoken:username:password
|
||||
account = token.split(':')[1] + ':' + token.split(':')[-1]
|
||||
token = token.split(':')[0]
|
||||
|
||||
config = conf.PyngrokConfig(
|
||||
auth_token=token, region=region
|
||||
)
|
||||
try:
|
||||
public_url = ngrok.connect(port, pyngrok_config=config).public_url
|
||||
if account is None:
|
||||
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
|
||||
else:
|
||||
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
|
||||
except exception.PyngrokNgrokError:
|
||||
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
|
||||
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
|
||||
|
@ -9,7 +9,7 @@ sys.path.insert(0, script_path)
|
||||
|
||||
# search for directory of stable diffusion in following places
|
||||
sd_path = None
|
||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
|
||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
|
||||
for possible_sd_path in possible_sd_paths:
|
||||
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
|
||||
sd_path = os.path.abspath(possible_sd_path)
|
||||
|
@ -2,6 +2,7 @@ import json
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
@ -12,15 +13,21 @@ from skimage import exposure
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.face_restoration
|
||||
import modules.images as images
|
||||
import modules.styles
|
||||
import modules.sd_models as sd_models
|
||||
import modules.sd_vae as sd_vae
|
||||
import logging
|
||||
from ldm.data.util import AddMiDaS
|
||||
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
||||
|
||||
from einops import repeat, rearrange
|
||||
from blendmodes.blend import blendLayers, BlendType
|
||||
|
||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||
opt_C = 4
|
||||
@ -33,34 +40,68 @@ def setup_color_correction(image):
|
||||
return correction_target
|
||||
|
||||
|
||||
def apply_color_correction(correction, image):
|
||||
def apply_color_correction(correction, original_image):
|
||||
logging.info("Applying color correction.")
|
||||
image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
|
||||
cv2.cvtColor(
|
||||
np.asarray(image),
|
||||
np.asarray(original_image),
|
||||
cv2.COLOR_RGB2LAB
|
||||
),
|
||||
correction,
|
||||
channel_axis=2
|
||||
), cv2.COLOR_LAB2RGB).astype("uint8"))
|
||||
|
||||
image = blendLayers(image, original_image, BlendType.LUMINOSITY)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def get_correct_sampler(p):
|
||||
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
|
||||
return sd_samplers.samplers
|
||||
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
|
||||
return sd_samplers.samplers_for_img2img
|
||||
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
|
||||
return sd_samplers.samplers
|
||||
def apply_overlay(image, paste_loc, index, overlays):
|
||||
if overlays is None or index >= len(overlays):
|
||||
return image
|
||||
|
||||
overlay = overlays[index]
|
||||
|
||||
if paste_loc is not None:
|
||||
x, y, w, h = paste_loc
|
||||
base_image = Image.new('RGBA', (overlay.width, overlay.height))
|
||||
image = images.resize_image(1, image, w, h)
|
||||
base_image.paste(image, (x, y))
|
||||
image = base_image
|
||||
|
||||
image = image.convert('RGBA')
|
||||
image.alpha_composite(overlay)
|
||||
image = image.convert('RGB')
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def txt2img_image_conditioning(sd_model, x, width, height):
|
||||
if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
|
||||
# Dummy zero conditioning if we're not using inpainting model.
|
||||
# Still takes up a bit of memory, but no encoder call.
|
||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||
|
||||
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
|
||||
class StableDiffusionProcessing():
|
||||
"""
|
||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||
|
||||
"""
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
|
||||
if sampler_index is not None:
|
||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||
|
||||
self.sd_model = sd_model
|
||||
self.outpath_samples: str = outpath_samples
|
||||
self.outpath_grids: str = outpath_grids
|
||||
@ -73,7 +114,7 @@ class StableDiffusionProcessing():
|
||||
self.subseed_strength: float = subseed_strength
|
||||
self.seed_resize_from_h: int = seed_resize_from_h
|
||||
self.seed_resize_from_w: int = seed_resize_from_w
|
||||
self.sampler_index: int = sampler_index
|
||||
self.sampler_name: str = sampler_name
|
||||
self.batch_size: int = batch_size
|
||||
self.n_iter: int = n_iter
|
||||
self.steps: int = steps
|
||||
@ -90,13 +131,16 @@ class StableDiffusionProcessing():
|
||||
self.do_not_reload_embeddings = do_not_reload_embeddings
|
||||
self.paste_to = None
|
||||
self.color_corrections = None
|
||||
self.denoising_strength: float = 0
|
||||
self.denoising_strength: float = denoising_strength
|
||||
self.sampler_noise_scheduler_override = None
|
||||
self.ddim_discretize = opts.ddim_discretize
|
||||
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
||||
self.s_churn = s_churn or opts.s_churn
|
||||
self.s_tmin = s_tmin or opts.s_tmin
|
||||
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
||||
self.s_noise = s_noise or opts.s_noise
|
||||
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
||||
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
||||
self.is_using_inpainting_conditioning = False
|
||||
|
||||
if not seed_enable_extras:
|
||||
self.subseed = -1
|
||||
@ -104,16 +148,100 @@ class StableDiffusionProcessing():
|
||||
self.seed_resize_from_h = 0
|
||||
self.seed_resize_from_w = 0
|
||||
|
||||
self.scripts = None
|
||||
self.script_args = None
|
||||
self.all_prompts = None
|
||||
self.all_negative_prompts = None
|
||||
self.all_seeds = None
|
||||
self.all_subseeds = None
|
||||
self.iteration = 0
|
||||
|
||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||
|
||||
return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
|
||||
|
||||
def depth2img_image_conditioning(self, source_image):
|
||||
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
|
||||
transformer = AddMiDaS(model_type="dpt_hybrid")
|
||||
transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
|
||||
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
||||
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
||||
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
|
||||
conditioning = torch.nn.functional.interpolate(
|
||||
self.sd_model.depth_model(midas_in),
|
||||
size=conditioning_image.shape[2:],
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
(depth_min, depth_max) = torch.aminmax(conditioning)
|
||||
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
|
||||
return conditioning
|
||||
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
|
||||
self.is_using_inpainting_conditioning = True
|
||||
|
||||
# Handle the different mask inputs
|
||||
if image_mask is not None:
|
||||
if torch.is_tensor(image_mask):
|
||||
conditioning_mask = image_mask
|
||||
else:
|
||||
conditioning_mask = np.array(image_mask.convert("L"))
|
||||
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
||||
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
||||
|
||||
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
||||
conditioning_mask = torch.round(conditioning_mask)
|
||||
else:
|
||||
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
|
||||
|
||||
# Create another latent image, this time with a masked version of the original input.
|
||||
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
|
||||
conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
|
||||
conditioning_image = torch.lerp(
|
||||
source_image,
|
||||
source_image * (1.0 - conditioning_mask),
|
||||
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
|
||||
)
|
||||
|
||||
# Encode the new masked image using first stage of network.
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
||||
|
||||
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
||||
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
||||
image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
||||
image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
||||
# identify itself with a field common to all models. The conditioning_key is also hybrid.
|
||||
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
||||
return self.depth2img_image_conditioning(source_image)
|
||||
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||
|
||||
# Dummy zero conditioning if we're not using inpainting or depth model.
|
||||
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
pass
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
raise NotImplementedError()
|
||||
|
||||
def close(self):
|
||||
self.sd_model = None
|
||||
self.sampler = None
|
||||
|
||||
|
||||
class Processed:
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
|
||||
self.images = images_list
|
||||
self.prompt = p.prompt
|
||||
self.negative_prompt = p.negative_prompt
|
||||
@ -121,10 +249,10 @@ class Processed:
|
||||
self.subseed = subseed
|
||||
self.subseed_strength = p.subseed_strength
|
||||
self.info = info
|
||||
self.comments = comments
|
||||
self.width = p.width
|
||||
self.height = p.height
|
||||
self.sampler_index = p.sampler_index
|
||||
self.sampler = sd_samplers.samplers[p.sampler_index].name
|
||||
self.sampler_name = p.sampler_name
|
||||
self.cfg_scale = p.cfg_scale
|
||||
self.steps = p.steps
|
||||
self.batch_size = p.batch_size
|
||||
@ -151,17 +279,20 @@ class Processed:
|
||||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
||||
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
|
||||
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
||||
self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
|
||||
|
||||
self.all_prompts = all_prompts or [self.prompt]
|
||||
self.all_seeds = all_seeds or [self.seed]
|
||||
self.all_subseeds = all_subseeds or [self.subseed]
|
||||
self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
|
||||
self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
|
||||
self.all_seeds = all_seeds or p.all_seeds or [self.seed]
|
||||
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
|
||||
self.infotexts = infotexts or [info]
|
||||
|
||||
def js(self):
|
||||
obj = {
|
||||
"prompt": self.prompt,
|
||||
"prompt": self.all_prompts[0],
|
||||
"all_prompts": self.all_prompts,
|
||||
"negative_prompt": self.negative_prompt,
|
||||
"negative_prompt": self.all_negative_prompts[0],
|
||||
"all_negative_prompts": self.all_negative_prompts,
|
||||
"seed": self.seed,
|
||||
"all_seeds": self.all_seeds,
|
||||
"subseed": self.subseed,
|
||||
@ -169,8 +300,7 @@ class Processed:
|
||||
"subseed_strength": self.subseed_strength,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"sampler_index": self.sampler_index,
|
||||
"sampler": self.sampler,
|
||||
"sampler_name": self.sampler_name,
|
||||
"cfg_scale": self.cfg_scale,
|
||||
"steps": self.steps,
|
||||
"batch_size": self.batch_size,
|
||||
@ -186,6 +316,7 @@ class Processed:
|
||||
"styles": self.styles,
|
||||
"job_timestamp": self.job_timestamp,
|
||||
"clip_skip": self.clip_skip,
|
||||
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
|
||||
}
|
||||
|
||||
return json.dumps(obj)
|
||||
@ -210,13 +341,14 @@ def slerp(val, low, high):
|
||||
|
||||
|
||||
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
|
||||
eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
|
||||
xs = []
|
||||
|
||||
# if we have multiple seeds, this means we are working with batch size>1; this then
|
||||
# enables the generation of additional tensors with noise that the sampler will use during its processing.
|
||||
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
|
||||
# produce the same images as with two batches [100], [101].
|
||||
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
|
||||
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
|
||||
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
|
||||
else:
|
||||
sampler_noises = None
|
||||
@ -256,8 +388,8 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
if sampler_noises is not None:
|
||||
cnt = p.sampler.number_of_needed_noises(p)
|
||||
|
||||
if opts.eta_noise_seed_delta > 0:
|
||||
torch.manual_seed(seed + opts.eta_noise_seed_delta)
|
||||
if eta_noise_seed_delta > 0:
|
||||
torch.manual_seed(seed + eta_noise_seed_delta)
|
||||
|
||||
for j in range(cnt):
|
||||
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
|
||||
@ -297,20 +429,23 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
||||
|
||||
generation_params = {
|
||||
"Steps": p.steps,
|
||||
"Sampler": get_correct_sampler(p)[p.sampler_index].name,
|
||||
"Sampler": p.sampler_name,
|
||||
"CFG scale": p.cfg_scale,
|
||||
"Seed": all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
|
||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
||||
"Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
|
||||
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||
@ -318,14 +453,38 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
||||
|
||||
generation_params.update(p.extra_generation_params)
|
||||
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||
|
||||
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
||||
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
|
||||
|
||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||
|
||||
|
||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
||||
|
||||
try:
|
||||
for k, v in p.override_settings.items():
|
||||
setattr(opts, k, v)
|
||||
if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet
|
||||
if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model
|
||||
if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE
|
||||
|
||||
res = process_images_inner(p)
|
||||
|
||||
finally:
|
||||
# restore opts to original state
|
||||
if p.override_settings_restore_afterwards:
|
||||
for k, v in stored_opts.items():
|
||||
setattr(opts, k, v)
|
||||
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
|
||||
if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
|
||||
if k == 'sd_vae': sd_vae.reload_vae_weights()
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
if type(p.prompt) == list:
|
||||
@ -333,10 +492,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
else:
|
||||
assert p.prompt is not None
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
seed = get_fixed_seed(p.seed)
|
||||
@ -347,57 +502,71 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
comments = {}
|
||||
|
||||
shared.prompt_styles.apply_styles(p)
|
||||
|
||||
if type(p.prompt) == list:
|
||||
all_prompts = p.prompt
|
||||
p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
|
||||
else:
|
||||
all_prompts = p.batch_size * p.n_iter * [p.prompt]
|
||||
p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
|
||||
|
||||
if type(p.negative_prompt) == list:
|
||||
p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
|
||||
else:
|
||||
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
|
||||
|
||||
if type(seed) == list:
|
||||
all_seeds = seed
|
||||
p.all_seeds = seed
|
||||
else:
|
||||
all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
|
||||
p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
|
||||
|
||||
if type(subseed) == list:
|
||||
all_subseeds = subseed
|
||||
p.all_subseeds = subseed
|
||||
else:
|
||||
all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
|
||||
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
||||
|
||||
def infotext(iteration=0, position_in_batch=0):
|
||||
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
||||
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
||||
|
||||
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.process(p)
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
|
||||
with torch.no_grad(), p.sd_model.ema_scope():
|
||||
with devices.autocast():
|
||||
p.init(all_prompts, all_seeds, all_subseeds)
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
|
||||
if state.job_count == -1:
|
||||
state.job_count = p.n_iter
|
||||
|
||||
for n in range(p.n_iter):
|
||||
p.iteration = n
|
||||
|
||||
if state.skipped:
|
||||
state.skipped = False
|
||||
|
||||
if state.interrupted:
|
||||
break
|
||||
|
||||
prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
|
||||
if (len(prompts) == 0):
|
||||
if len(prompts) == 0:
|
||||
break
|
||||
|
||||
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
|
||||
#c = p.sd_model.get_learned_conditioning(prompts)
|
||||
if p.scripts is not None:
|
||||
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
|
||||
|
||||
with devices.autocast():
|
||||
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
||||
uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
|
||||
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
|
||||
|
||||
if len(model_hijack.comments) > 0:
|
||||
@ -408,10 +577,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
with devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
|
||||
|
||||
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
||||
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
||||
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
|
||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
del samples_ddim
|
||||
@ -421,9 +590,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
if opts.filter_nsfw:
|
||||
import modules.safety as safety
|
||||
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
||||
if p.scripts is not None:
|
||||
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
||||
|
||||
for i, x_sample in enumerate(x_samples_ddim):
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
@ -442,22 +610,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
|
||||
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
|
||||
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
|
||||
image = apply_color_correction(p.color_corrections[i], image)
|
||||
|
||||
if p.overlay_images is not None and i < len(p.overlay_images):
|
||||
overlay = p.overlay_images[i]
|
||||
|
||||
if p.paste_to is not None:
|
||||
x, y, w, h = p.paste_to
|
||||
base_image = Image.new('RGBA', (overlay.width, overlay.height))
|
||||
image = images.resize_image(1, image, w, h)
|
||||
base_image.paste(image, (x, y))
|
||||
image = base_image
|
||||
|
||||
image = image.convert('RGBA')
|
||||
image.alpha_composite(overlay)
|
||||
image = image.convert('RGB')
|
||||
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
|
||||
if opts.samples_save and not p.do_not_save_samples:
|
||||
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
|
||||
@ -490,23 +647,33 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
index_of_first_image = 1
|
||||
|
||||
if opts.grid_save:
|
||||
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
|
||||
devices.torch_gc()
|
||||
return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.postprocess(p, res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
sampler = None
|
||||
|
||||
def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
|
||||
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.enable_hr = enable_hr
|
||||
self.denoising_strength = denoising_strength
|
||||
self.firstphase_width = firstphase_width
|
||||
self.firstphase_height = firstphase_height
|
||||
self.truncate_x = 0
|
||||
self.truncate_y = 0
|
||||
self.hr_scale = hr_scale
|
||||
self.hr_upscaler = hr_upscaler
|
||||
|
||||
if firstphase_width != 0 or firstphase_height != 0:
|
||||
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
|
||||
self.hr_scale = self.width / firstphase_width
|
||||
self.width = firstphase_width
|
||||
self.height = firstphase_height
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
if self.enable_hr:
|
||||
@ -515,48 +682,50 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
else:
|
||||
state.job_count = state.job_count * 2
|
||||
|
||||
if self.firstphase_width == 0 or self.firstphase_height == 0:
|
||||
desired_pixel_count = 512 * 512
|
||||
actual_pixel_count = self.width * self.height
|
||||
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
|
||||
self.firstphase_width = math.ceil(scale * self.width / 64) * 64
|
||||
self.firstphase_height = math.ceil(scale * self.height / 64) * 64
|
||||
firstphase_width_truncated = int(scale * self.width)
|
||||
firstphase_height_truncated = int(scale * self.height)
|
||||
self.extra_generation_params["Hires upscale"] = self.hr_scale
|
||||
if self.hr_upscaler is not None:
|
||||
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
|
||||
|
||||
else:
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
width_ratio = self.width / self.firstphase_width
|
||||
height_ratio = self.height / self.firstphase_height
|
||||
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
||||
if self.enable_hr and latent_scale_mode is None:
|
||||
assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
|
||||
|
||||
if width_ratio > height_ratio:
|
||||
firstphase_width_truncated = self.firstphase_width
|
||||
firstphase_height_truncated = self.firstphase_width * self.height / self.width
|
||||
else:
|
||||
firstphase_width_truncated = self.firstphase_height * self.width / self.height
|
||||
firstphase_height_truncated = self.firstphase_height
|
||||
|
||||
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
|
||||
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
||||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
|
||||
if not self.enable_hr:
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
|
||||
return samples
|
||||
|
||||
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
|
||||
target_width = int(self.width * self.hr_scale)
|
||||
target_height = int(self.height * self.hr_scale)
|
||||
|
||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
||||
def save_intermediate(image, index):
|
||||
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
|
||||
|
||||
if opts.use_scale_latent_for_hires_fix:
|
||||
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
|
||||
return
|
||||
|
||||
if not isinstance(image, Image.Image):
|
||||
image = sd_samplers.sample_to_image(image, index, approximation=0)
|
||||
|
||||
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
|
||||
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
|
||||
|
||||
if latent_scale_mode is not None:
|
||||
for i in range(samples.shape[0]):
|
||||
save_intermediate(samples, i)
|
||||
|
||||
samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
|
||||
|
||||
# Avoid making the inpainting conditioning unless necessary as
|
||||
# this does need some extra compute to decode / encode the image again.
|
||||
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
|
||||
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
|
||||
else:
|
||||
image_conditioning = self.txt2img_image_conditioning(samples)
|
||||
else:
|
||||
decoded_samples = decode_first_stage(self.sd_model, samples)
|
||||
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
@ -566,7 +735,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
image = Image.fromarray(x_sample)
|
||||
image = images.resize_image(0, image, self.width, self.height)
|
||||
|
||||
save_intermediate(image, i)
|
||||
|
||||
image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = np.moveaxis(image, 2, 0)
|
||||
batch_images.append(image)
|
||||
@ -577,17 +749,19 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
|
||||
|
||||
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
|
||||
|
||||
shared.state.nextjob()
|
||||
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
|
||||
|
||||
# GC now before running the next img2img to prevent running out of memory
|
||||
x = None
|
||||
devices.torch_gc()
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
|
||||
|
||||
return samples
|
||||
|
||||
@ -595,7 +769,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
sampler = None
|
||||
|
||||
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
|
||||
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.init_images = init_images
|
||||
@ -603,7 +777,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
self.denoising_strength: float = denoising_strength
|
||||
self.init_latent = None
|
||||
self.image_mask = mask
|
||||
#self.image_unblurred_mask = None
|
||||
self.latent_mask = None
|
||||
self.mask_for_overlay = None
|
||||
self.mask_blur = mask_blur
|
||||
@ -611,65 +784,68 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
self.inpaint_full_res = inpaint_full_res
|
||||
self.inpaint_full_res_padding = inpaint_full_res_padding
|
||||
self.inpainting_mask_invert = inpainting_mask_invert
|
||||
self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.image_conditioning = None
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
crop_region = None
|
||||
|
||||
if self.image_mask is not None:
|
||||
self.image_mask = self.image_mask.convert('L')
|
||||
image_mask = self.image_mask
|
||||
|
||||
if image_mask is not None:
|
||||
image_mask = image_mask.convert('L')
|
||||
|
||||
if self.inpainting_mask_invert:
|
||||
self.image_mask = ImageOps.invert(self.image_mask)
|
||||
|
||||
#self.image_unblurred_mask = self.image_mask
|
||||
image_mask = ImageOps.invert(image_mask)
|
||||
|
||||
if self.mask_blur > 0:
|
||||
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
||||
image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
||||
|
||||
if self.inpaint_full_res:
|
||||
self.mask_for_overlay = self.image_mask
|
||||
mask = self.image_mask.convert('L')
|
||||
self.mask_for_overlay = image_mask
|
||||
mask = image_mask.convert('L')
|
||||
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
|
||||
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
|
||||
x1, y1, x2, y2 = crop_region
|
||||
|
||||
mask = mask.crop(crop_region)
|
||||
self.image_mask = images.resize_image(2, mask, self.width, self.height)
|
||||
image_mask = images.resize_image(2, mask, self.width, self.height)
|
||||
self.paste_to = (x1, y1, x2-x1, y2-y1)
|
||||
else:
|
||||
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
|
||||
np_mask = np.array(self.image_mask)
|
||||
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
|
||||
np_mask = np.array(image_mask)
|
||||
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
|
||||
self.mask_for_overlay = Image.fromarray(np_mask)
|
||||
|
||||
self.overlay_images = []
|
||||
|
||||
latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
|
||||
latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
|
||||
|
||||
add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
|
||||
if add_color_corrections:
|
||||
self.color_corrections = []
|
||||
imgs = []
|
||||
for img in self.init_images:
|
||||
image = img.convert("RGB")
|
||||
image = images.flatten(img, opts.img2img_background_color)
|
||||
|
||||
if crop_region is None:
|
||||
if crop_region is None and self.resize_mode != 3:
|
||||
image = images.resize_image(self.resize_mode, image, self.width, self.height)
|
||||
|
||||
if self.image_mask is not None:
|
||||
if image_mask is not None:
|
||||
image_masked = Image.new('RGBa', (image.width, image.height))
|
||||
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
|
||||
|
||||
self.overlay_images.append(image_masked.convert('RGBA'))
|
||||
|
||||
# crop_region is not None if we are doing inpaint full res
|
||||
if crop_region is not None:
|
||||
image = image.crop(crop_region)
|
||||
image = images.resize_image(2, image, self.width, self.height)
|
||||
|
||||
if self.image_mask is not None:
|
||||
if image_mask is not None:
|
||||
if self.inpainting_fill != 1:
|
||||
image = masking.fill(image, latent_mask)
|
||||
|
||||
@ -685,6 +861,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
|
||||
if self.overlay_images is not None:
|
||||
self.overlay_images = self.overlay_images * self.batch_size
|
||||
|
||||
if self.color_corrections is not None and len(self.color_corrections) == 1:
|
||||
self.color_corrections = self.color_corrections * self.batch_size
|
||||
|
||||
elif len(imgs) <= self.batch_size:
|
||||
self.batch_size = len(imgs)
|
||||
batch_images = np.array(imgs)
|
||||
@ -697,7 +877,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
|
||||
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
||||
|
||||
if self.image_mask is not None:
|
||||
if self.resize_mode == 3:
|
||||
self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||
|
||||
if image_mask is not None:
|
||||
init_mask = latent_mask
|
||||
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
||||
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
|
||||
@ -714,10 +897,16 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
elif self.inpainting_fill == 3:
|
||||
self.init_latent = self.init_latent * self.mask
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
|
||||
if self.initial_noise_multiplier != 1.0:
|
||||
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
|
||||
x *= self.initial_noise_multiplier
|
||||
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||
|
||||
if self.mask is not None:
|
||||
samples = samples * self.nmask + self.init_latent * self.mask
|
||||
|
113
modules/safe.py
113
modules/safe.py
@ -23,23 +23,30 @@ def encode(*args):
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
extra_handler = None
|
||||
|
||||
def persistent_load(self, saved_id):
|
||||
assert saved_id[0] == 'storage'
|
||||
return TypedStorage()
|
||||
|
||||
def find_class(self, module, name):
|
||||
if self.extra_handler is not None:
|
||||
res = self.extra_handler(module, name)
|
||||
if res is not None:
|
||||
return res
|
||||
|
||||
if module == 'collections' and name == 'OrderedDict':
|
||||
return getattr(collections, name)
|
||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
|
||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
|
||||
return getattr(torch._utils, name)
|
||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
|
||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
|
||||
return getattr(torch, name)
|
||||
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||
return getattr(torch.nn.modules.container, name)
|
||||
if module == 'numpy.core.multiarray' and name == 'scalar':
|
||||
return numpy.core.multiarray.scalar
|
||||
if module == 'numpy' and name == 'dtype':
|
||||
return numpy.dtype
|
||||
if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
|
||||
return getattr(numpy.core.multiarray, name)
|
||||
if module == 'numpy' and name in ['dtype', 'ndarray']:
|
||||
return getattr(numpy, name)
|
||||
if module == '_codecs' and name == 'encode':
|
||||
return encode
|
||||
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
|
||||
@ -52,32 +59,37 @@ class RestrictedUnpickler(pickle.Unpickler):
|
||||
return set
|
||||
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
|
||||
raise Exception(f"global '{module}/{name}' is forbidden")
|
||||
|
||||
|
||||
allowed_zip_names = ["archive/data.pkl", "archive/version"]
|
||||
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
|
||||
|
||||
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
|
||||
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
|
||||
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
|
||||
|
||||
def check_zip_filenames(filename, names):
|
||||
for name in names:
|
||||
if name in allowed_zip_names:
|
||||
continue
|
||||
if allowed_zip_names_re.match(name):
|
||||
continue
|
||||
|
||||
raise Exception(f"bad file inside {filename}: {name}")
|
||||
|
||||
|
||||
def check_pt(filename):
|
||||
def check_pt(filename, extra_handler):
|
||||
try:
|
||||
|
||||
# new pytorch format is a zip file
|
||||
with zipfile.ZipFile(filename) as z:
|
||||
check_zip_filenames(filename, z.namelist())
|
||||
|
||||
with z.open('archive/data.pkl') as file:
|
||||
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
||||
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
||||
if len(data_pkl_filenames) == 0:
|
||||
raise Exception(f"data.pkl not found in {filename}")
|
||||
if len(data_pkl_filenames) > 1:
|
||||
raise Exception(f"Multiple data.pkl found in {filename}")
|
||||
with z.open(data_pkl_filenames[0]) as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
unpickler.extra_handler = extra_handler
|
||||
unpickler.load()
|
||||
|
||||
except zipfile.BadZipfile:
|
||||
@ -85,33 +97,96 @@ def check_pt(filename):
|
||||
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
||||
with open(filename, "rb") as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
unpickler.extra_handler = extra_handler
|
||||
for i in range(5):
|
||||
unpickler.load()
|
||||
|
||||
|
||||
def load(filename, *args, **kwargs):
|
||||
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
|
||||
|
||||
|
||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||
"""
|
||||
this function is intended to be used by extensions that want to load models with
|
||||
some extra classes in them that the usual unpickler would find suspicious.
|
||||
|
||||
Use the extra_handler argument to specify a function that takes module and field name as text,
|
||||
and returns that field's value:
|
||||
|
||||
```python
|
||||
def extra(module, name):
|
||||
if module == 'collections' and name == 'OrderedDict':
|
||||
return collections.OrderedDict
|
||||
|
||||
return None
|
||||
|
||||
safe.load_with_extra('model.pt', extra_handler=extra)
|
||||
```
|
||||
|
||||
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
|
||||
definitely unsafe.
|
||||
"""
|
||||
|
||||
from modules import shared
|
||||
|
||||
try:
|
||||
if not shared.cmd_opts.disable_safe_unpickle:
|
||||
check_pt(filename)
|
||||
check_pt(filename, extra_handler)
|
||||
|
||||
except pickle.UnpicklingError:
|
||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
|
||||
print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
|
||||
print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
|
||||
print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
|
||||
print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
|
||||
print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
|
||||
print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
|
||||
return None
|
||||
|
||||
return unsafe_torch_load(filename, *args, **kwargs)
|
||||
|
||||
|
||||
class Extra:
|
||||
"""
|
||||
A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
|
||||
(because it's not your code making the torch.load call). The intended use is like this:
|
||||
|
||||
```
|
||||
import torch
|
||||
from modules import safe
|
||||
|
||||
def handler(module, name):
|
||||
if module == 'torch' and name in ['float64', 'float16']:
|
||||
return getattr(torch, name)
|
||||
|
||||
return None
|
||||
|
||||
with safe.Extra(handler):
|
||||
x = torch.load('model.pt')
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, handler):
|
||||
self.handler = handler
|
||||
|
||||
def __enter__(self):
|
||||
global global_extra_handler
|
||||
|
||||
assert global_extra_handler is None, 'already inside an Extra() block'
|
||||
global_extra_handler = self.handler
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
global global_extra_handler
|
||||
|
||||
global_extra_handler = None
|
||||
|
||||
|
||||
unsafe_torch_load = torch.load
|
||||
torch.load = load
|
||||
global_extra_handler = None
|
||||
|
||||
|
@ -1,42 +0,0 @@
|
||||
import torch
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
from PIL import Image
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = None
|
||||
safety_checker = None
|
||||
|
||||
def numpy_to_pil(images):
|
||||
"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
images = (images * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
|
||||
return pil_images
|
||||
|
||||
# check and replace nsfw content
|
||||
def check_safety(x_image):
|
||||
global safety_feature_extractor, safety_checker
|
||||
|
||||
if safety_feature_extractor is None:
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
|
||||
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||
|
||||
return x_checked_image, has_nsfw_concept
|
||||
|
||||
|
||||
def censor_batch(x):
|
||||
x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
|
||||
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
|
||||
x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
|
||||
|
||||
return x
|
281
modules/script_callbacks.py
Normal file
281
modules/script_callbacks.py
Normal file
@ -0,0 +1,281 @@
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
import inspect
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from gradio import Blocks
|
||||
|
||||
|
||||
def report_exception(c, job):
|
||||
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
class ImageSaveParams:
|
||||
def __init__(self, image, p, filename, pnginfo):
|
||||
self.image = image
|
||||
"""the PIL image itself"""
|
||||
|
||||
self.p = p
|
||||
"""p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""
|
||||
|
||||
self.filename = filename
|
||||
"""name of file that the image would be saved to"""
|
||||
|
||||
self.pnginfo = pnginfo
|
||||
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
|
||||
|
||||
|
||||
class CFGDenoiserParams:
|
||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
|
||||
self.image_cond = image_cond
|
||||
"""Conditioning image"""
|
||||
|
||||
self.sigma = sigma
|
||||
"""Current sigma noise step value"""
|
||||
|
||||
self.sampling_step = sampling_step
|
||||
"""Current Sampling step number"""
|
||||
|
||||
self.total_sampling_steps = total_sampling_steps
|
||||
"""Total number of sampling steps planned"""
|
||||
|
||||
|
||||
class UiTrainTabParams:
|
||||
def __init__(self, txt2img_preview_params):
|
||||
self.txt2img_preview_params = txt2img_preview_params
|
||||
|
||||
|
||||
class ImageGridLoopParams:
|
||||
def __init__(self, imgs, cols, rows):
|
||||
self.imgs = imgs
|
||||
self.cols = cols
|
||||
self.rows = rows
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callback_map = dict(
|
||||
callbacks_app_started=[],
|
||||
callbacks_model_loaded=[],
|
||||
callbacks_ui_tabs=[],
|
||||
callbacks_ui_train_tabs=[],
|
||||
callbacks_ui_settings=[],
|
||||
callbacks_before_image_saved=[],
|
||||
callbacks_image_saved=[],
|
||||
callbacks_cfg_denoiser=[],
|
||||
callbacks_before_component=[],
|
||||
callbacks_after_component=[],
|
||||
callbacks_image_grid=[],
|
||||
)
|
||||
|
||||
|
||||
def clear_callbacks():
|
||||
for callback_list in callback_map.values():
|
||||
callback_list.clear()
|
||||
|
||||
|
||||
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
||||
for c in callback_map['callbacks_app_started']:
|
||||
try:
|
||||
c.callback(demo, app)
|
||||
except Exception:
|
||||
report_exception(c, 'app_started_callback')
|
||||
|
||||
|
||||
def model_loaded_callback(sd_model):
|
||||
for c in callback_map['callbacks_model_loaded']:
|
||||
try:
|
||||
c.callback(sd_model)
|
||||
except Exception:
|
||||
report_exception(c, 'model_loaded_callback')
|
||||
|
||||
|
||||
def ui_tabs_callback():
|
||||
res = []
|
||||
|
||||
for c in callback_map['callbacks_ui_tabs']:
|
||||
try:
|
||||
res += c.callback() or []
|
||||
except Exception:
|
||||
report_exception(c, 'ui_tabs_callback')
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def ui_train_tabs_callback(params: UiTrainTabParams):
|
||||
for c in callback_map['callbacks_ui_train_tabs']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'callbacks_ui_train_tabs')
|
||||
|
||||
|
||||
def ui_settings_callback():
|
||||
for c in callback_map['callbacks_ui_settings']:
|
||||
try:
|
||||
c.callback()
|
||||
except Exception:
|
||||
report_exception(c, 'ui_settings_callback')
|
||||
|
||||
|
||||
def before_image_saved_callback(params: ImageSaveParams):
|
||||
for c in callback_map['callbacks_before_image_saved']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'before_image_saved_callback')
|
||||
|
||||
|
||||
def image_saved_callback(params: ImageSaveParams):
|
||||
for c in callback_map['callbacks_image_saved']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'image_saved_callback')
|
||||
|
||||
|
||||
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
||||
for c in callback_map['callbacks_cfg_denoiser']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'cfg_denoiser_callback')
|
||||
|
||||
|
||||
def before_component_callback(component, **kwargs):
|
||||
for c in callback_map['callbacks_before_component']:
|
||||
try:
|
||||
c.callback(component, **kwargs)
|
||||
except Exception:
|
||||
report_exception(c, 'before_component_callback')
|
||||
|
||||
|
||||
def after_component_callback(component, **kwargs):
|
||||
for c in callback_map['callbacks_after_component']:
|
||||
try:
|
||||
c.callback(component, **kwargs)
|
||||
except Exception:
|
||||
report_exception(c, 'after_component_callback')
|
||||
|
||||
|
||||
def image_grid_callback(params: ImageGridLoopParams):
|
||||
for c in callback_map['callbacks_image_grid']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'image_grid')
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
|
||||
callbacks.append(ScriptCallback(filename, fun))
|
||||
|
||||
|
||||
def remove_current_script_callbacks():
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||
if filename == 'unknown file':
|
||||
return
|
||||
for callback_list in callback_map.values():
|
||||
for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
|
||||
callback_list.remove(callback_to_remove)
|
||||
|
||||
|
||||
def remove_callbacks_for_function(callback_func):
|
||||
for callback_list in callback_map.values():
|
||||
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
|
||||
callback_list.remove(callback_to_remove)
|
||||
|
||||
|
||||
def on_app_started(callback):
|
||||
"""register a function to be called when the webui started, the gradio `Block` component and
|
||||
fastapi `FastAPI` object are passed as the arguments"""
|
||||
add_callback(callback_map['callbacks_app_started'], callback)
|
||||
|
||||
|
||||
def on_model_loaded(callback):
|
||||
"""register a function to be called when the stable diffusion model is created; the model is
|
||||
passed as an argument"""
|
||||
add_callback(callback_map['callbacks_model_loaded'], callback)
|
||||
|
||||
|
||||
def on_ui_tabs(callback):
|
||||
"""register a function to be called when the UI is creating new tabs.
|
||||
The function must either return a None, which means no new tabs to be added, or a list, where
|
||||
each element is a tuple:
|
||||
(gradio_component, title, elem_id)
|
||||
|
||||
gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
|
||||
title is tab text displayed to user in the UI
|
||||
elem_id is HTML id for the tab
|
||||
"""
|
||||
add_callback(callback_map['callbacks_ui_tabs'], callback)
|
||||
|
||||
|
||||
def on_ui_train_tabs(callback):
|
||||
"""register a function to be called when the UI is creating new tabs for the train tab.
|
||||
Create your new tabs with gr.Tab.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_ui_train_tabs'], callback)
|
||||
|
||||
|
||||
def on_ui_settings(callback):
|
||||
"""register a function to be called before UI settings are populated; add your settings
|
||||
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
||||
add_callback(callback_map['callbacks_ui_settings'], callback)
|
||||
|
||||
|
||||
def on_before_image_saved(callback):
|
||||
"""register a function to be called before an image is saved to a file.
|
||||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_before_image_saved'], callback)
|
||||
|
||||
|
||||
def on_image_saved(callback):
|
||||
"""register a function to be called after an image is saved to a file.
|
||||
The callback is called with one argument:
|
||||
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_image_saved'], callback)
|
||||
|
||||
|
||||
def on_cfg_denoiser(callback):
|
||||
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
||||
The callback is called with one argument:
|
||||
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
|
||||
|
||||
|
||||
def on_before_component(callback):
|
||||
"""register a function to be called before a component is created.
|
||||
The callback is called with arguments:
|
||||
- component - gradio component that is about to be created.
|
||||
- **kwargs - args to gradio.components.IOComponent.__init__ function
|
||||
|
||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||
This can be useful to inject your own components somewhere in the middle of vanilla UI.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_before_component'], callback)
|
||||
|
||||
|
||||
def on_after_component(callback):
|
||||
"""register a function to be called after a component is created. See on_before_component for more."""
|
||||
add_callback(callback_map['callbacks_after_component'], callback)
|
||||
|
||||
|
||||
def on_image_grid(callback):
|
||||
"""register a function to be called before making an image grid.
|
||||
The callback is called with one argument:
|
||||
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_image_grid'], callback)
|
34
modules/script_loading.py
Normal file
34
modules/script_loading.py
Normal file
@ -0,0 +1,34 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def load_module(path):
|
||||
with open(path, "r", encoding="utf8") as file:
|
||||
text = file.read()
|
||||
|
||||
compiled = compile(text, path, 'exec')
|
||||
module = ModuleType(os.path.basename(path))
|
||||
exec(compiled, module.__dict__)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def preload_extensions(extensions_dir, parser):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname in sorted(os.listdir(extensions_dir)):
|
||||
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
|
||||
if not os.path.isfile(preload_script):
|
||||
continue
|
||||
|
||||
try:
|
||||
module = load_module(preload_script)
|
||||
if hasattr(module, 'preload'):
|
||||
module.preload(parser)
|
||||
|
||||
except Exception:
|
||||
print(f"Error running preload() for {preload_script}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
@ -1,86 +1,211 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
|
||||
import modules.ui as ui
|
||||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules import shared
|
||||
from modules import shared, paths, script_callbacks, extensions, script_loading
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
||||
|
||||
class Script:
|
||||
filename = None
|
||||
args_from = None
|
||||
args_to = None
|
||||
alwayson = False
|
||||
|
||||
is_txt2img = False
|
||||
is_img2img = False
|
||||
|
||||
"""A gr.Group component that has all script's UI inside it"""
|
||||
group = None
|
||||
|
||||
infotext_fields = None
|
||||
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
||||
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
|
||||
"""
|
||||
|
||||
# The title of the script. This is what will be displayed in the dropdown menu.
|
||||
def title(self):
|
||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
# How the script is displayed in the UI. See https://gradio.app/docs/#components
|
||||
# for the different UI components you can use and how to create them.
|
||||
# Most UI components can return a value, such as a boolean for a checkbox.
|
||||
# The returned values are passed to the run method as parameters.
|
||||
def ui(self, is_img2img):
|
||||
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||
The return value should be an array of all components that are used in processing.
|
||||
Values of those returned components will be passed to run() and process() functions.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
# Determines when the script should be shown in the dropdown menu via the
|
||||
# returned value. As an example:
|
||||
# is_img2img is True if the current tab is img2img, and False if it is txt2img.
|
||||
# Thus, return is_img2img to only show the script on the img2img tab.
|
||||
def show(self, is_img2img):
|
||||
"""
|
||||
is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
|
||||
|
||||
This function should return:
|
||||
- False if the script should not be shown in UI at all
|
||||
- True if the script should be shown in UI if it's selected in the scripts dropdown
|
||||
- script.AlwaysVisible if the script should be shown in UI at all times
|
||||
"""
|
||||
|
||||
return True
|
||||
|
||||
# This is where the additional processing is implemented. The parameters include
|
||||
# self, the model object "p" (a StableDiffusionProcessing class, see
|
||||
# processing.py), and the parameters returned by the ui method.
|
||||
# Custom functions can be defined here, and additional libraries can be imported
|
||||
# to be used in processing. The return value should be a Processed object, which is
|
||||
# what is returned by the process_images method.
|
||||
def run(self, *args):
|
||||
def run(self, p, *args):
|
||||
"""
|
||||
This function is called if the script has been selected in the script dropdown.
|
||||
It must do all processing and return the Processed object with results, same as
|
||||
one returned by processing.process_images.
|
||||
|
||||
Usually the processing is done by calling the processing.process_images function.
|
||||
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
# The description method is currently unused.
|
||||
# To add a description that appears when hovering over the title, amend the "titles"
|
||||
# dict in script.js to include the script title (returned by title) as a key, and
|
||||
# your description as the value.
|
||||
def process(self, p, *args):
|
||||
"""
|
||||
This function is called before processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def process_batch(self, p, *args, **kwargs):
|
||||
"""
|
||||
Same as process(), but called for every batch.
|
||||
|
||||
**kwargs will have those items:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
|
||||
- seeds - list of seeds for current batch
|
||||
- subseeds - list of subseeds for current batch
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess_batch(self, p, *args, **kwargs):
|
||||
"""
|
||||
Same as process_batch(), but called for every batch after it has been generated.
|
||||
|
||||
**kwargs will have same items as process_batch, and also:
|
||||
- batch_number - index of current batch, from 0 to number of batches-1
|
||||
- images - torch tensor with all generated images, with values ranging from 0 to 1;
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
"""
|
||||
Called before a component is created.
|
||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||
This can be useful to inject your own components somewhere in the middle of vanilla UI.
|
||||
You can return created components in the ui() function to add them to the list of arguments for your processing functions
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
"""
|
||||
Called after a component is created. Same as above.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def describe(self):
|
||||
"""unused"""
|
||||
return ""
|
||||
|
||||
|
||||
current_basedir = paths.script_path
|
||||
|
||||
|
||||
def basedir():
|
||||
"""returns the base directory for the current script. For scripts in the main scripts directory,
|
||||
this is the main directory (where webui.py resides), and for scripts in extensions directory
|
||||
(ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
|
||||
"""
|
||||
return current_basedir
|
||||
|
||||
|
||||
scripts_data = []
|
||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
|
||||
|
||||
|
||||
def load_scripts(basedir):
|
||||
if not os.path.exists(basedir):
|
||||
return
|
||||
def list_scripts(scriptdirname, extension):
|
||||
scripts_list = []
|
||||
|
||||
basedir = os.path.join(paths.script_path, scriptdirname)
|
||||
if os.path.exists(basedir):
|
||||
for filename in sorted(os.listdir(basedir)):
|
||||
path = os.path.join(basedir, filename)
|
||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
||||
|
||||
if os.path.splitext(path)[1].lower() != '.py':
|
||||
for ext in extensions.active():
|
||||
scripts_list += ext.list_files(scriptdirname, extension)
|
||||
|
||||
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
return scripts_list
|
||||
|
||||
|
||||
def list_files_with_name(filename):
|
||||
res = []
|
||||
|
||||
dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
|
||||
|
||||
for dirpath in dirs:
|
||||
if not os.path.isdir(dirpath):
|
||||
continue
|
||||
|
||||
if not os.path.isfile(path):
|
||||
continue
|
||||
path = os.path.join(dirpath, filename)
|
||||
if os.path.isfile(path):
|
||||
res.append(path)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def load_scripts():
|
||||
global current_basedir
|
||||
scripts_data.clear()
|
||||
script_callbacks.clear_callbacks()
|
||||
|
||||
scripts_list = list_scripts("scripts", ".py")
|
||||
|
||||
syspath = sys.path
|
||||
|
||||
for scriptfile in sorted(scripts_list):
|
||||
try:
|
||||
with open(path, "r", encoding="utf8") as file:
|
||||
text = file.read()
|
||||
if scriptfile.basedir != paths.script_path:
|
||||
sys.path = [scriptfile.basedir] + sys.path
|
||||
current_basedir = scriptfile.basedir
|
||||
|
||||
from types import ModuleType
|
||||
compiled = compile(text, path, 'exec')
|
||||
module = ModuleType(filename)
|
||||
exec(compiled, module.__dict__)
|
||||
module = script_loading.load_module(scriptfile.path)
|
||||
|
||||
for key, script_class in module.__dict__.items():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
scripts_data.append((script_class, path))
|
||||
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
|
||||
|
||||
except Exception:
|
||||
print(f"Error loading script: {filename}", file=sys.stderr)
|
||||
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
finally:
|
||||
sys.path = syspath
|
||||
current_basedir = paths.script_path
|
||||
|
||||
|
||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
try:
|
||||
@ -96,64 +221,94 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
class ScriptRunner:
|
||||
def __init__(self):
|
||||
self.scripts = []
|
||||
self.selectable_scripts = []
|
||||
self.alwayson_scripts = []
|
||||
self.titles = []
|
||||
self.infotext_fields = []
|
||||
|
||||
def setup_ui(self, is_img2img):
|
||||
for script_class, path in scripts_data:
|
||||
def initialize_scripts(self, is_img2img):
|
||||
self.scripts.clear()
|
||||
self.alwayson_scripts.clear()
|
||||
self.selectable_scripts.clear()
|
||||
|
||||
for script_class, path, basedir in scripts_data:
|
||||
script = script_class()
|
||||
script.filename = path
|
||||
script.is_txt2img = not is_img2img
|
||||
script.is_img2img = is_img2img
|
||||
|
||||
if not script.show(is_img2img):
|
||||
continue
|
||||
visibility = script.show(script.is_img2img)
|
||||
|
||||
if visibility == AlwaysVisible:
|
||||
self.scripts.append(script)
|
||||
self.alwayson_scripts.append(script)
|
||||
script.alwayson = True
|
||||
|
||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
|
||||
elif visibility:
|
||||
self.scripts.append(script)
|
||||
self.selectable_scripts.append(script)
|
||||
|
||||
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
|
||||
dropdown.save_to_config = True
|
||||
inputs = [dropdown]
|
||||
def setup_ui(self):
|
||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||
|
||||
for script in self.scripts:
|
||||
inputs = [None]
|
||||
inputs_alwayson = [True]
|
||||
|
||||
def create_script_ui(script, inputs, inputs_alwayson):
|
||||
script.args_from = len(inputs)
|
||||
script.args_to = len(inputs)
|
||||
|
||||
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
|
||||
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
||||
|
||||
if controls is None:
|
||||
continue
|
||||
return
|
||||
|
||||
for control in controls:
|
||||
control.custom_script_source = os.path.basename(script.filename)
|
||||
control.visible = False
|
||||
|
||||
if script.infotext_fields is not None:
|
||||
self.infotext_fields += script.infotext_fields
|
||||
|
||||
inputs += controls
|
||||
inputs_alwayson += [script.alwayson for _ in controls]
|
||||
script.args_to = len(inputs)
|
||||
|
||||
def select_script(script_index):
|
||||
if 0 < script_index <= len(self.scripts):
|
||||
script = self.scripts[script_index-1]
|
||||
args_from = script.args_from
|
||||
args_to = script.args_to
|
||||
else:
|
||||
args_from = 0
|
||||
args_to = 0
|
||||
for script in self.alwayson_scripts:
|
||||
with gr.Group() as group:
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
|
||||
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
|
||||
script.group = group
|
||||
|
||||
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
|
||||
dropdown.save_to_config = True
|
||||
inputs[0] = dropdown
|
||||
|
||||
for script in self.selectable_scripts:
|
||||
with gr.Group(visible=False) as group:
|
||||
create_script_ui(script, inputs, inputs_alwayson)
|
||||
|
||||
script.group = group
|
||||
|
||||
def select_script(script_index):
|
||||
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
|
||||
|
||||
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
|
||||
|
||||
def init_field(title):
|
||||
"""called when an initial value is set from ui-config.json to show script's UI components"""
|
||||
|
||||
if title == 'None':
|
||||
return
|
||||
|
||||
script_index = self.titles.index(title)
|
||||
script = self.scripts[script_index]
|
||||
for i in range(script.args_from, script.args_to):
|
||||
inputs[i].visible = True
|
||||
self.selectable_scripts[script_index].group.visible = True
|
||||
|
||||
dropdown.init_field = init_field
|
||||
|
||||
dropdown.change(
|
||||
fn=select_script,
|
||||
inputs=[dropdown],
|
||||
outputs=inputs
|
||||
outputs=[script.group for script in self.selectable_scripts]
|
||||
)
|
||||
|
||||
return inputs
|
||||
@ -164,7 +319,7 @@ class ScriptRunner:
|
||||
if script_index == 0:
|
||||
return None
|
||||
|
||||
script = self.scripts[script_index-1]
|
||||
script = self.selectable_scripts[script_index-1]
|
||||
|
||||
if script is None:
|
||||
return None
|
||||
@ -176,19 +331,68 @@ class ScriptRunner:
|
||||
|
||||
return processed
|
||||
|
||||
def reload_sources(self):
|
||||
def process(self, p):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.process(p, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running process: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def process_batch(self, p, **kwargs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.process_batch(p, *script_args, **kwargs)
|
||||
except Exception:
|
||||
print(f"Error running process_batch: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def postprocess(self, p, processed):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess(p, processed, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def postprocess_batch(self, p, images, **kwargs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess_batch(p, *script_args, images=images, **kwargs)
|
||||
except Exception:
|
||||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
for script in self.scripts:
|
||||
try:
|
||||
script.before_component(component, **kwargs)
|
||||
except Exception:
|
||||
print(f"Error running before_component: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
for script in self.scripts:
|
||||
try:
|
||||
script.after_component(component, **kwargs)
|
||||
except Exception:
|
||||
print(f"Error running after_component: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def reload_sources(self, cache):
|
||||
for si, script in list(enumerate(self.scripts)):
|
||||
with open(script.filename, "r", encoding="utf8") as file:
|
||||
args_from = script.args_from
|
||||
args_to = script.args_to
|
||||
filename = script.filename
|
||||
text = file.read()
|
||||
|
||||
from types import ModuleType
|
||||
|
||||
compiled = compile(text, filename, 'exec')
|
||||
module = ModuleType(script.filename)
|
||||
exec(compiled, module.__dict__)
|
||||
module = cache.get(filename, None)
|
||||
if module is None:
|
||||
module = script_loading.load_module(script.filename)
|
||||
cache[filename] = module
|
||||
|
||||
for key, script_class in module.__dict__.items():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
@ -197,19 +401,42 @@ class ScriptRunner:
|
||||
self.scripts[si].args_from = args_from
|
||||
self.scripts[si].args_to = args_to
|
||||
|
||||
|
||||
scripts_txt2img = ScriptRunner()
|
||||
scripts_img2img = ScriptRunner()
|
||||
scripts_current: ScriptRunner = None
|
||||
|
||||
|
||||
def reload_script_body_only():
|
||||
scripts_txt2img.reload_sources()
|
||||
scripts_img2img.reload_sources()
|
||||
cache = {}
|
||||
scripts_txt2img.reload_sources(cache)
|
||||
scripts_img2img.reload_sources(cache)
|
||||
|
||||
|
||||
def reload_scripts(basedir):
|
||||
def reload_scripts():
|
||||
global scripts_txt2img, scripts_img2img
|
||||
|
||||
scripts_data.clear()
|
||||
load_scripts(basedir)
|
||||
load_scripts()
|
||||
|
||||
scripts_txt2img = ScriptRunner()
|
||||
scripts_img2img = ScriptRunner()
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
if scripts_current is not None:
|
||||
scripts_current.before_component(self, **kwargs)
|
||||
|
||||
script_callbacks.before_component_callback(self, **kwargs)
|
||||
|
||||
res = original_IOComponent_init(self, *args, **kwargs)
|
||||
|
||||
script_callbacks.after_component_callback(self, **kwargs)
|
||||
|
||||
if scripts_current is not None:
|
||||
scripts_current.after_component(self, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
original_IOComponent_init = gr.components.IOComponent.__init__
|
||||
gr.components.IOComponent.__init__ = IOComponent_init
|
||||
|
@ -1,60 +1,81 @@
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import einsum
|
||||
from torch.nn.functional import silu
|
||||
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
||||
from modules.shared import opts, device, cmd_opts
|
||||
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.shared import cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||
|
||||
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
import ldm.modules.encoders.modules
|
||||
|
||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
|
||||
# new memory efficient cross attention blocks do not support hypernets and we already
|
||||
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
||||
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||
|
||||
# silence new console spam from SD2
|
||||
ldm.modules.attention.print = lambda *args: None
|
||||
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
||||
|
||||
|
||||
def apply_optimizations():
|
||||
undo_optimizations()
|
||||
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||
|
||||
optimization_method = None
|
||||
|
||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||
print("Applying xformers cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
||||
optimization_method = 'xformers'
|
||||
elif cmd_opts.opt_split_attention_v1:
|
||||
print("Applying v1 cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||
optimization_method = 'V1'
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
|
||||
if not invokeAI_mps_available and shared.device.type == 'mps':
|
||||
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
|
||||
print("Applying v1 cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||
optimization_method = 'V1'
|
||||
else:
|
||||
print("Applying cross attention optimization (InvokeAI).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
|
||||
optimization_method = 'InvokeAI'
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||
print("Applying cross attention optimization (Doggettx).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||
optimization_method = 'Doggettx'
|
||||
|
||||
return optimization_method
|
||||
|
||||
|
||||
def undo_optimizations():
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
|
||||
def get_target_prompt_token_count(token_count):
|
||||
return math.ceil(max(token_count, 1) / 75) * 75
|
||||
def fix_checkpoint():
|
||||
ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
|
||||
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
|
||||
|
||||
|
||||
class StableDiffusionModelHijack:
|
||||
@ -63,18 +84,31 @@ class StableDiffusionModelHijack:
|
||||
layers = None
|
||||
circular_enabled = False
|
||||
clip = None
|
||||
optimization_method = None
|
||||
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||
|
||||
def hijack(self, m):
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
|
||||
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
||||
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
self.optimization_method = apply_optimizations()
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
|
||||
apply_optimizations()
|
||||
fix_checkpoint()
|
||||
|
||||
def flatten(el):
|
||||
flattened = [flatten(children) for children in el.children()]
|
||||
@ -86,12 +120,23 @@ class StableDiffusionModelHijack:
|
||||
self.layers = flatten(m)
|
||||
|
||||
def undo_hijack(self, m):
|
||||
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||
elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
self.apply_circular(False)
|
||||
self.layers = None
|
||||
self.clip = None
|
||||
|
||||
def apply_circular(self, enable):
|
||||
if self.circular_enabled == enable:
|
||||
@ -107,263 +152,8 @@ class StableDiffusionModelHijack:
|
||||
|
||||
def tokenize(self, text):
|
||||
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.hijack: StableDiffusionModelHijack = hijack
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
self.token_mults = {}
|
||||
|
||||
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
||||
|
||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
for text, ident in tokens_with_parens:
|
||||
mult = 1.0
|
||||
for c in text:
|
||||
if c == '[':
|
||||
mult /= 1.1
|
||||
if c == ']':
|
||||
mult *= 1.1
|
||||
if c == '(':
|
||||
mult *= 1.1
|
||||
if c == ')':
|
||||
mult /= 1.1
|
||||
|
||||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||
id_end = self.wrapped.tokenizer.eos_token_id
|
||||
|
||||
if opts.enable_emphasis:
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
||||
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
last_comma = -1
|
||||
|
||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
if token == self.comma_token:
|
||||
last_comma = len(remade_tokens)
|
||||
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
last_comma += 1
|
||||
reloc_tokens = remade_tokens[last_comma:]
|
||||
reloc_mults = multipliers[last_comma:]
|
||||
|
||||
remade_tokens = remade_tokens[:last_comma]
|
||||
length = len(remade_tokens)
|
||||
|
||||
rem = int(math.ceil(length / 75)) * 75 - length
|
||||
remade_tokens += [id_end] * rem + reloc_tokens
|
||||
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
||||
|
||||
if embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
iteration = len(remade_tokens) // 75
|
||||
if (len(remade_tokens) + emb_len) // 75 != iteration:
|
||||
rem = (75 * (iteration + 1) - len(remade_tokens))
|
||||
remade_tokens += [id_end] * rem
|
||||
multipliers += [1.0] * rem
|
||||
iteration += 1
|
||||
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [weight] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
prompt_target_length = get_target_prompt_token_count(token_count)
|
||||
tokens_to_add = prompt_target_length - len(remade_tokens)
|
||||
|
||||
remade_tokens = remade_tokens + [id_end] * tokens_to_add
|
||||
multipliers = multipliers + [1.0] * tokens_to_add
|
||||
|
||||
return remade_tokens, fixes, multipliers, token_count
|
||||
|
||||
def process_text(self, texts):
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_multipliers = []
|
||||
for line in texts:
|
||||
if line in cache:
|
||||
remade_tokens, fixes, multipliers = cache[line]
|
||||
else:
|
||||
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
||||
token_count = max(current_token_count, token_count)
|
||||
|
||||
cache[line] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
|
||||
def process_text_old(self, text):
|
||||
id_start = self.wrapped.tokenizer.bos_token_id
|
||||
id_end = self.wrapped.tokenizer.eos_token_id
|
||||
maxlen = self.wrapped.max_length # you get to stay at 77
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
overflowing_words = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
batch_multipliers = []
|
||||
for tokens in batch_tokens:
|
||||
tuple_tokens = tuple(tokens)
|
||||
|
||||
if tuple_tokens in cache:
|
||||
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||
else:
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
mult = 1.0
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
i += 1
|
||||
elif embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(mult)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
fixes.append((len(remade_tokens), embedding))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [mult] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
ovf = remade_tokens[maxlen - 2:]
|
||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def forward(self, text):
|
||||
use_old = opts.use_old_emphasis_implementation
|
||||
if use_old:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||
else:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||
|
||||
self.hijack.comments += hijack_comments
|
||||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
|
||||
if use_old:
|
||||
self.hijack.fixes = hijack_fixes
|
||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
|
||||
z = None
|
||||
i = 0
|
||||
while max(map(len, remade_batch_tokens)) != 0:
|
||||
rem_tokens = [x[75:] for x in remade_batch_tokens]
|
||||
rem_multipliers = [x[75:] for x in batch_multipliers]
|
||||
|
||||
self.hijack.fixes = []
|
||||
for unfiltered in hijack_fixes:
|
||||
fixes = []
|
||||
for fix in unfiltered:
|
||||
if fix[0] == i:
|
||||
fixes.append(fix[1])
|
||||
self.hijack.fixes.append(fixes)
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
for j in range(len(remade_batch_tokens)):
|
||||
if len(remade_batch_tokens[j]) > 0:
|
||||
tokens.append(remade_batch_tokens[j][:75])
|
||||
multipliers.append(batch_multipliers[j][:75])
|
||||
else:
|
||||
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
|
||||
multipliers.append([1.0] * 75)
|
||||
|
||||
z1 = self.process_tokens(tokens, multipliers)
|
||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||
|
||||
remade_batch_tokens = rem_tokens
|
||||
batch_multipliers = rem_multipliers
|
||||
i += 1
|
||||
|
||||
return z
|
||||
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
if not opts.use_old_emphasis_implementation:
|
||||
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
|
||||
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
|
||||
if opts.CLIP_stop_at_last_layers > 1:
|
||||
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||
else:
|
||||
z = outputs.last_hidden_state
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
|
||||
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
|
||||
original_mean = z.mean()
|
||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z *= original_mean / new_mean
|
||||
|
||||
return z
|
||||
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
|
||||
|
||||
|
||||
class EmbeddingsWithFixes(torch.nn.Module):
|
||||
@ -385,8 +175,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, embedding in fixes:
|
||||
emb = embedding.vec
|
||||
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
|
||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||
|
||||
vecs.append(tensor)
|
||||
|
||||
@ -403,3 +193,19 @@ def add_circular_option_to_conv_2d():
|
||||
|
||||
|
||||
model_hijack = StableDiffusionModelHijack()
|
||||
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
"""
|
||||
Fix register buffer bug for Mac OS.
|
||||
"""
|
||||
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != devices.device:
|
||||
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
||||
|
||||
setattr(self, name, attr)
|
||||
|
||||
|
||||
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
||||
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
||||
|
10
modules/sd_hijack_checkpoint.py
Normal file
10
modules/sd_hijack_checkpoint.py
Normal file
@ -0,0 +1,10 @@
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
def BasicTransformerBlock_forward(self, x, context=None):
|
||||
return checkpoint(self._forward, x, context)
|
||||
|
||||
def AttentionBlock_forward(self, x):
|
||||
return checkpoint(self._forward, x)
|
||||
|
||||
def ResBlock_forward(self, x, emb):
|
||||
return checkpoint(self._forward, x, emb)
|
303
modules/sd_hijack_clip.py
Normal file
303
modules/sd_hijack_clip.py
Normal file
@ -0,0 +1,303 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from modules import prompt_parser, devices
|
||||
from modules.shared import opts
|
||||
|
||||
def get_target_prompt_token_count(token_count):
|
||||
return math.ceil(max(token_count, 1) / 75) * 75
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.hijack = hijack
|
||||
|
||||
def tokenize(self, texts):
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||
if opts.enable_emphasis:
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
||||
tokenized = self.tokenize([text for text, _ in parsed])
|
||||
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
last_comma = -1
|
||||
|
||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
if token == self.comma_token:
|
||||
last_comma = len(remade_tokens)
|
||||
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
last_comma += 1
|
||||
reloc_tokens = remade_tokens[last_comma:]
|
||||
reloc_mults = multipliers[last_comma:]
|
||||
|
||||
remade_tokens = remade_tokens[:last_comma]
|
||||
length = len(remade_tokens)
|
||||
|
||||
rem = int(math.ceil(length / 75)) * 75 - length
|
||||
remade_tokens += [self.id_end] * rem + reloc_tokens
|
||||
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
||||
|
||||
if embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
iteration = len(remade_tokens) // 75
|
||||
if (len(remade_tokens) + emb_len) // 75 != iteration:
|
||||
rem = (75 * (iteration + 1) - len(remade_tokens))
|
||||
remade_tokens += [self.id_end] * rem
|
||||
multipliers += [1.0] * rem
|
||||
iteration += 1
|
||||
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [weight] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
prompt_target_length = get_target_prompt_token_count(token_count)
|
||||
tokens_to_add = prompt_target_length - len(remade_tokens)
|
||||
|
||||
remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
|
||||
multipliers = multipliers + [1.0] * tokens_to_add
|
||||
|
||||
return remade_tokens, fixes, multipliers, token_count
|
||||
|
||||
def process_text(self, texts):
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_multipliers = []
|
||||
for line in texts:
|
||||
if line in cache:
|
||||
remade_tokens, fixes, multipliers = cache[line]
|
||||
else:
|
||||
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
||||
token_count = max(current_token_count, token_count)
|
||||
|
||||
cache[line] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def process_text_old(self, texts):
|
||||
id_start = self.id_start
|
||||
id_end = self.id_end
|
||||
maxlen = self.wrapped.max_length # you get to stay at 77
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_tokens = self.tokenize(texts)
|
||||
batch_multipliers = []
|
||||
for tokens in batch_tokens:
|
||||
tuple_tokens = tuple(tokens)
|
||||
|
||||
if tuple_tokens in cache:
|
||||
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||
else:
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
mult = 1.0
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
i += 1
|
||||
elif embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(mult)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
fixes.append((len(remade_tokens), embedding))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [mult] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
ovf = remade_tokens[maxlen - 2:]
|
||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def forward(self, text):
|
||||
use_old = opts.use_old_emphasis_implementation
|
||||
if use_old:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||
else:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||
|
||||
self.hijack.comments += hijack_comments
|
||||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
|
||||
if use_old:
|
||||
self.hijack.fixes = hijack_fixes
|
||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
|
||||
z = None
|
||||
i = 0
|
||||
while max(map(len, remade_batch_tokens)) != 0:
|
||||
rem_tokens = [x[75:] for x in remade_batch_tokens]
|
||||
rem_multipliers = [x[75:] for x in batch_multipliers]
|
||||
|
||||
self.hijack.fixes = []
|
||||
for unfiltered in hijack_fixes:
|
||||
fixes = []
|
||||
for fix in unfiltered:
|
||||
if fix[0] == i:
|
||||
fixes.append(fix[1])
|
||||
self.hijack.fixes.append(fixes)
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
for j in range(len(remade_batch_tokens)):
|
||||
if len(remade_batch_tokens[j]) > 0:
|
||||
tokens.append(remade_batch_tokens[j][:75])
|
||||
multipliers.append(batch_multipliers[j][:75])
|
||||
else:
|
||||
tokens.append([self.id_end] * 75)
|
||||
multipliers.append([1.0] * 75)
|
||||
|
||||
z1 = self.process_tokens(tokens, multipliers)
|
||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||
|
||||
remade_batch_tokens = rem_tokens
|
||||
batch_multipliers = rem_multipliers
|
||||
i += 1
|
||||
|
||||
return z
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
if not opts.use_old_emphasis_implementation:
|
||||
remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
|
||||
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
|
||||
|
||||
if self.id_end != self.id_pad:
|
||||
for batch_pos in range(len(remade_batch_tokens)):
|
||||
index = remade_batch_tokens[batch_pos].index(self.id_end)
|
||||
tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
|
||||
|
||||
z = self.encode_with_transformers(tokens)
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
|
||||
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z *= original_mean / new_mean
|
||||
|
||||
return z
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
|
||||
self.comma_token = vocab.get(',</w>', None)
|
||||
|
||||
self.token_mults = {}
|
||||
tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
for text, ident in tokens_with_parens:
|
||||
mult = 1.0
|
||||
for c in text:
|
||||
if c == '[':
|
||||
mult /= 1.1
|
||||
if c == ']':
|
||||
mult *= 1.1
|
||||
if c == '(':
|
||||
mult *= 1.1
|
||||
if c == ')':
|
||||
mult /= 1.1
|
||||
|
||||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
self.id_start = self.wrapped.tokenizer.bos_token_id
|
||||
self.id_end = self.wrapped.tokenizer.eos_token_id
|
||||
self.id_pad = self.id_end
|
||||
|
||||
def tokenize(self, texts):
|
||||
tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
return tokenized
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
|
||||
if opts.CLIP_stop_at_last_layers > 1:
|
||||
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||
else:
|
||||
z = outputs.last_hidden_state
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
embedding_layer = self.wrapped.transformer.text_model.embeddings
|
||||
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
|
||||
return embedded
|
111
modules/sd_hijack_inpainting.py
Normal file
111
modules/sd_hijack_inpainting.py
Normal file
@ -0,0 +1,111 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
from einops import repeat
|
||||
from omegaconf import ListConfig
|
||||
|
||||
import ldm.models.diffusion.ddpm
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
||||
|
||||
|
||||
def should_hijack_inpainting(checkpoint_info):
|
||||
from modules import sd_models
|
||||
|
||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
|
||||
|
||||
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||
|
||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
37
modules/sd_hijack_open_clip.py
Normal file
37
modules/sd_hijack_open_clip.py
Normal file
@ -0,0 +1,37 @@
|
||||
import open_clip.tokenizer
|
||||
import torch
|
||||
|
||||
from modules import sd_hijack_clip, devices
|
||||
from modules.shared import opts
|
||||
|
||||
tokenizer = open_clip.tokenizer._tokenizer
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
||||
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
|
||||
self.id_start = tokenizer.encoder["<start_of_text>"]
|
||||
self.id_end = tokenizer.encoder["<end_of_text>"]
|
||||
self.id_pad = 0
|
||||
|
||||
def tokenize(self, texts):
|
||||
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
|
||||
|
||||
tokenized = [tokenizer.encode(text) for text in texts]
|
||||
|
||||
return tokenized
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
# set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
|
||||
z = self.wrapped.encode_with_transformer(tokens)
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
ids = tokenizer.encode(init_text)
|
||||
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
||||
|
||||
return embedded
|
@ -127,7 +127,7 @@ def check_for_psutil():
|
||||
|
||||
invokeAI_mps_available = check_for_psutil()
|
||||
|
||||
# -- Taken from https://github.com/invoke-ai/InvokeAI --
|
||||
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
||||
if invokeAI_mps_available:
|
||||
import psutil
|
||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
|
||||
return einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
if slice_size % 4096 == 0:
|
||||
slice_size -= 1
|
||||
return einsum_op_slice_1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(q, k, v):
|
||||
if mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
|
||||
return einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
return einsum_op_slice_0(q, k, v, 1)
|
||||
@ -188,7 +190,7 @@ def einsum_op(q, k, v):
|
||||
return einsum_op_cuda(q, k, v)
|
||||
|
||||
if q.device.type == 'mps':
|
||||
if mem_total_gb >= 32:
|
||||
if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
|
||||
return einsum_op_mps_v1(q, k, v)
|
||||
return einsum_op_mps_v2(q, k, v)
|
||||
|
||||
|
30
modules/sd_hijack_unet.py
Normal file
30
modules/sd_hijack_unet.py
Normal file
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
|
||||
|
||||
class TorchHijackForUnet:
|
||||
"""
|
||||
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
|
||||
this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
|
||||
"""
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'cat':
|
||||
return self.cat
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
|
||||
def cat(self, tensors, *args, **kwargs):
|
||||
if len(tensors) == 2:
|
||||
a, b = tensors
|
||||
if a.shape[-2:] != b.shape[-2:]:
|
||||
a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
|
||||
|
||||
tensors = (a, b)
|
||||
|
||||
return torch.cat(tensors, *args, **kwargs)
|
||||
|
||||
|
||||
th = TorchHijackForUnet()
|
34
modules/sd_hijack_xlmr.py
Normal file
34
modules/sd_hijack_xlmr.py
Normal file
@ -0,0 +1,34 @@
|
||||
import open_clip.tokenizer
|
||||
import torch
|
||||
|
||||
from modules import sd_hijack_clip, devices
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
||||
self.id_start = wrapped.config.bos_token_id
|
||||
self.id_end = wrapped.config.eos_token_id
|
||||
self.id_pad = wrapped.config.pad_token_id
|
||||
|
||||
self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
# there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
|
||||
# trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
|
||||
# layer to work with - you have to use the last
|
||||
|
||||
attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
|
||||
features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
|
||||
z = features['projection_state']
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
embedding_layer = self.wrapped.roberta.embeddings
|
||||
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||
|
||||
return embedded
|
@ -1,26 +1,33 @@
|
||||
import collections
|
||||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
from omegaconf import OmegaConf
|
||||
from os import mkdir
|
||||
from urllib import request
|
||||
import ldm.modules.midas as midas
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import shared, modelloader, devices
|
||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae
|
||||
from modules.paths import models_path
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
|
||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
|
||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
||||
checkpoints_list = {}
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
|
||||
from transformers import logging
|
||||
from transformers import logging, CLIPModel
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except Exception:
|
||||
@ -32,15 +39,26 @@ def setup_model():
|
||||
os.makedirs(model_path)
|
||||
|
||||
list_models()
|
||||
enable_midas_autodownload()
|
||||
|
||||
|
||||
def checkpoint_tiles():
|
||||
return sorted([x.title for x in checkpoints_list.values()])
|
||||
convert = lambda name: int(name) if name.isdigit() else name.lower()
|
||||
alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
|
||||
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
|
||||
|
||||
|
||||
def find_checkpoint_config(info):
|
||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
return shared.cmd_opts.config
|
||||
|
||||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
|
||||
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
|
||||
|
||||
def modeltitle(path, shorthash):
|
||||
abspath = os.path.abspath(path)
|
||||
@ -63,7 +81,7 @@ def list_models():
|
||||
if os.path.exists(cmd_ckpt):
|
||||
h = model_hash(cmd_ckpt)
|
||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
|
||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
||||
shared.opts.data['sd_model_checkpoint'] = title
|
||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||
@ -71,12 +89,7 @@ def list_models():
|
||||
h = model_hash(filename)
|
||||
title, short_model_name = modeltitle(filename, h)
|
||||
|
||||
basename, _ = os.path.splitext(filename)
|
||||
config = basename + ".yaml"
|
||||
if not os.path.exists(config):
|
||||
config = shared.cmd_opts.config
|
||||
|
||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
|
||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
||||
|
||||
|
||||
def get_closet_checkpoint_match(searchString):
|
||||
@ -101,18 +114,19 @@ def model_hash(filename):
|
||||
|
||||
def select_checkpoint():
|
||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||
|
||||
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
|
||||
if len(checkpoints_list) == 0:
|
||||
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
|
||||
print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
|
||||
if shared.cmd_opts.ckpt is not None:
|
||||
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
|
||||
print(f" - directory {model_path}", file=sys.stderr)
|
||||
if shared.cmd_opts.ckpt_dir is not None:
|
||||
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
||||
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
||||
print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
checkpoint_info = next(iter(checkpoints_list.values()))
|
||||
@ -138,8 +152,8 @@ def transform_checkpoint_dict_key(k):
|
||||
|
||||
|
||||
def get_state_dict_from_checkpoint(pl_sd):
|
||||
if "state_dict" in pl_sd:
|
||||
pl_sd = pl_sd["state_dict"]
|
||||
pl_sd = pl_sd.pop("state_dict", pl_sd)
|
||||
pl_sd.pop("state_dict", None)
|
||||
|
||||
sd = {}
|
||||
for k, v in pl_sd.items():
|
||||
@ -154,64 +168,156 @@ def get_state_dict_from_checkpoint(pl_sd):
|
||||
return pl_sd
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info):
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
sd_model_hash = checkpoint_info.hash
|
||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
device = map_location or shared.weight_load_location
|
||||
if device is None:
|
||||
device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||
|
||||
if checkpoint_info not in checkpoints_loaded:
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||
|
||||
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
|
||||
if "global_step" in pl_sd:
|
||||
if print_global_state and "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
|
||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||
missing, extra = model.load_state_dict(sd, strict=False)
|
||||
return sd
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
sd_model_hash = checkpoint_info.hash
|
||||
|
||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||
|
||||
if cache_enabled and checkpoint_info in checkpoints_loaded:
|
||||
# use checkpoint cache
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||
else:
|
||||
# load from file
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||
|
||||
sd = read_state_dict(checkpoint_file)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
del sd
|
||||
|
||||
if cache_enabled:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
vae = model.first_stage_model
|
||||
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
|
||||
model.half()
|
||||
model.first_stage_model = vae
|
||||
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
|
||||
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
|
||||
|
||||
if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
|
||||
if os.path.exists(vae_file):
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||
model.first_stage_model.load_state_dict(vae_dict)
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
# clean up cache if limit is reached
|
||||
if cache_enabled:
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
else:
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
checkpoints_loaded.move_to_end(checkpoint_info)
|
||||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpoint = checkpoint_file
|
||||
model.sd_checkpoint_info = checkpoint_info
|
||||
|
||||
model.logvar = model.logvar.to(devices.device) # fix for training
|
||||
|
||||
def load_model():
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
sd_vae.load_vae(model, vae_file)
|
||||
|
||||
|
||||
def enable_midas_autodownload():
|
||||
"""
|
||||
Gives the ldm.modules.midas.api.load_model function automatic downloading.
|
||||
|
||||
When the 512-depth-ema model, and other future models like it, is loaded,
|
||||
it calls midas.api.load_model to load the associated midas depth model.
|
||||
This function applies a wrapper to download the model to the correct
|
||||
location automatically.
|
||||
"""
|
||||
|
||||
midas_path = os.path.join(models_path, 'midas')
|
||||
|
||||
# stable-diffusion-stability-ai hard-codes the midas model path to
|
||||
# a location that differs from where other scripts using this model look.
|
||||
# HACK: Overriding the path here.
|
||||
for k, v in midas.api.ISL_PATHS.items():
|
||||
file_name = os.path.basename(v)
|
||||
midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
|
||||
|
||||
midas_urls = {
|
||||
"dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
|
||||
"dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
|
||||
"midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
|
||||
"midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
|
||||
}
|
||||
|
||||
midas.api.load_model_inner = midas.api.load_model
|
||||
|
||||
def load_model_wrapper(model_type):
|
||||
path = midas.api.ISL_PATHS[model_type]
|
||||
if not os.path.exists(path):
|
||||
if not os.path.exists(midas_path):
|
||||
mkdir(midas_path)
|
||||
|
||||
print(f"Downloading midas model weights for {model_type} to {path}")
|
||||
request.urlretrieve(midas_urls[model_type], path)
|
||||
print(f"{model_type} downloaded")
|
||||
|
||||
return midas.api.load_model_inner(model_type)
|
||||
|
||||
midas.api.load_model = load_model_wrapper
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = select_checkpoint()
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
checkpoint_config = find_checkpoint_config(checkpoint_info)
|
||||
|
||||
if checkpoint_info.config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {checkpoint_info.config}")
|
||||
if checkpoint_config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {checkpoint_config}")
|
||||
|
||||
if shared.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
shared.sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
|
||||
if should_hijack_inpainting(checkpoint_info):
|
||||
# Hardcoded config for now...
|
||||
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||
sd_config.model.params.conditioning_key = "hybrid"
|
||||
sd_config.model.params.unet_config.params.in_channels = 9
|
||||
sd_config.model.params.finetune_keys = None
|
||||
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
do_inpainting_hijack()
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_info.config)
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
@ -222,21 +328,34 @@ def load_model():
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
sd_model.eval()
|
||||
shared.sd_model = sd_model
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
print("Model loaded.")
|
||||
|
||||
print(f"Model loaded.")
|
||||
return sd_model
|
||||
|
||||
|
||||
def reload_model_weights(sd_model, info=None):
|
||||
def reload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
|
||||
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
|
||||
if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
shared.sd_model = load_model()
|
||||
load_model(checkpoint_info)
|
||||
return shared.sd_model
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
@ -246,12 +365,19 @@ def reload_model_weights(sd_model, info=None):
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
try:
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
|
||||
except Exception as e:
|
||||
print("Failed to load checkpoint, restoring previous")
|
||||
load_model_weights(sd_model, current_checkpoint_info)
|
||||
raise
|
||||
finally:
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
print(f"Weights loaded.")
|
||||
print("Weights loaded.")
|
||||
|
||||
return sd_model
|
||||
|
@ -1,32 +1,41 @@
|
||||
from collections import namedtuple
|
||||
from collections import namedtuple, deque
|
||||
import numpy as np
|
||||
from math import floor
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
import inspect
|
||||
import k_diffusion.sampling
|
||||
import torchsde._brownian.brownian_interval
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
from modules import prompt_parser, devices, processing
|
||||
from modules import prompt_parser, devices, processing, images, sd_vae_approx
|
||||
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||
('Heun', 'sample_heun', ['k_heun'], {}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
||||
]
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
@ -40,13 +49,21 @@ all_samplers = [
|
||||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
]
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
|
||||
samplers = []
|
||||
samplers_for_img2img = []
|
||||
samplers_map = {}
|
||||
|
||||
|
||||
def create_sampler_with_index(list_of_configs, index, model):
|
||||
config = list_of_configs[index]
|
||||
def create_sampler(name, model):
|
||||
if name is not None:
|
||||
config = all_samplers_map.get(name, None)
|
||||
else:
|
||||
config = all_samplers[0]
|
||||
|
||||
assert config is not None, f'bad sampler name: {name}'
|
||||
|
||||
sampler = config.constructor(model)
|
||||
sampler.config = config
|
||||
|
||||
@ -62,6 +79,12 @@ def set_samplers():
|
||||
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||
|
||||
samplers_map.clear()
|
||||
for sampler in all_samplers:
|
||||
samplers_map[sampler.name.lower()] = sampler.name
|
||||
for alias in sampler.aliases:
|
||||
samplers_map[alias.lower()] = sampler.name
|
||||
|
||||
|
||||
set_samplers()
|
||||
|
||||
@ -71,6 +94,7 @@ sampler_extra_params = {
|
||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
}
|
||||
|
||||
|
||||
def setup_img2img_steps(p, steps=None):
|
||||
if opts.img2img_fix_steps or steps is not None:
|
||||
steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
|
||||
@ -82,14 +106,34 @@ def setup_img2img_steps(p, steps=None):
|
||||
return steps, t_enc
|
||||
|
||||
|
||||
def sample_to_image(samples):
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
if approximation is None:
|
||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||
|
||||
if approximation == 2:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
elif approximation == 1:
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||
else:
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
|
||||
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
return Image.fromarray(x_sample)
|
||||
|
||||
|
||||
def sample_to_image(samples, index=0, approximation=None):
|
||||
return single_sample_to_image(samples[index], approximation)
|
||||
|
||||
|
||||
def samples_to_image_grid(samples, approximation=None):
|
||||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||
|
||||
|
||||
def store_latent(decoded):
|
||||
state.current_latent = decoded
|
||||
|
||||
@ -105,7 +149,8 @@ class InterruptedException(BaseException):
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.sampler = constructor(sd_model)
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
|
||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
@ -117,6 +162,8 @@ class VanillaStableDiffusionSampler:
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return 0
|
||||
|
||||
@ -136,6 +183,12 @@ class VanillaStableDiffusionSampler:
|
||||
if self.stop_at is not None and self.step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
||||
image_conditioning = None
|
||||
if isinstance(cond, dict):
|
||||
image_conditioning = cond["c_concat"][0]
|
||||
cond = cond["c_crossattn"][0]
|
||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
@ -157,6 +210,12 @@ class VanillaStableDiffusionSampler:
|
||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||
|
||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||
# Note that they need to be lists because it just concatenates them later.
|
||||
if image_conditioning is not None:
|
||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
|
||||
if self.mask is not None:
|
||||
@ -182,39 +241,52 @@ class VanillaStableDiffusionSampler:
|
||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
||||
return num_steps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
steps = self.adjust_steps_if_invalid(p, steps)
|
||||
self.initialize(p)
|
||||
|
||||
# existing code fails with certain step counts, like 9
|
||||
try:
|
||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
except Exception:
|
||||
self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
|
||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||
|
||||
self.init_latent = x
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
self.initialize(p)
|
||||
|
||||
self.init_latent = None
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
steps = steps or p.steps
|
||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||
|
||||
# existing code fails with certain step counts, like 9
|
||||
try:
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
except Exception:
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
return samples_ddim
|
||||
|
||||
@ -228,7 +300,17 @@ class CFGDenoiser(torch.nn.Module):
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
|
||||
return denoised
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise InterruptedException
|
||||
|
||||
@ -239,35 +321,37 @@ class CFGDenoiser(torch.nn.Module):
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
|
||||
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
@ -278,34 +362,63 @@ class CFGDenoiser(torch.nn.Module):
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, kdiff_sampler):
|
||||
self.kdiff_sampler = kdiff_sampler
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.kdiff_sampler.randn_like
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
if x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
else:
|
||||
return torch.randn_like(x)
|
||||
|
||||
|
||||
# MPS fix for randn in torchsde
|
||||
def torchsde_randn(size, dtype, device, seed):
|
||||
if device.type == 'mps':
|
||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
||||
else:
|
||||
generator = torch.Generator(device).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=device, generator=generator)
|
||||
|
||||
|
||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
def __init__(self, funcname, sd_model):
|
||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.funcname = funcname
|
||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
self.sampler_noises = None
|
||||
self.sampler_noise_index = 0
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.default_eta = 1.0
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
@ -330,26 +443,13 @@ class KDiffusionSampler:
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def randn_like(self, x):
|
||||
noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
|
||||
|
||||
if noise is not None and x.shape == noise.shape:
|
||||
res = noise
|
||||
else:
|
||||
res = torch.randn_like(x)
|
||||
|
||||
self.sampler_noise_index += 1
|
||||
return res
|
||||
|
||||
def initialize(self, p):
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap.step = 0
|
||||
self.sampler_noise_index = 0
|
||||
self.eta = p.eta or opts.eta_ancestral
|
||||
|
||||
if self.sampler_noises is not None:
|
||||
k_diffusion.sampling.torch = TorchHijack(self)
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
@ -361,16 +461,26 @@ class KDiffusionSampler:
|
||||
|
||||
return extra_params_kwargs
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
|
||||
def get_sigmas(self, p, steps):
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
|
||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
|
||||
if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False):
|
||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||
|
||||
return sigmas
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||
xi = x + noise * sigma_sched[0]
|
||||
|
||||
@ -388,20 +498,21 @@ class KDiffusionSampler:
|
||||
extra_params_kwargs['sigmas'] = sigma_sched
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
||||
steps = steps or p.steps
|
||||
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
x = x * sigmas[0]
|
||||
|
||||
@ -414,7 +525,13 @@ class KDiffusionSampler:
|
||||
else:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
self.last_latent = x
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
||||
|
231
modules/sd_vae.py
Normal file
231
modules/sd_vae.py
Normal file
@ -0,0 +1,231 @@
|
||||
import torch
|
||||
import os
|
||||
import collections
|
||||
from collections import namedtuple
|
||||
from modules import shared, devices, script_callbacks
|
||||
from modules.paths import models_path
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
vae_dir = "VAE"
|
||||
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
|
||||
|
||||
|
||||
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||
|
||||
|
||||
default_vae_dict = {"auto": "auto", "None": None, None: None}
|
||||
default_vae_list = ["auto", "None"]
|
||||
|
||||
|
||||
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
|
||||
vae_dict = dict(default_vae_dict)
|
||||
vae_list = list(default_vae_list)
|
||||
first_load = True
|
||||
|
||||
|
||||
base_vae = None
|
||||
loaded_vae_file = None
|
||||
checkpoint_info = None
|
||||
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
def get_base_vae(model):
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||
return base_vae
|
||||
return None
|
||||
|
||||
|
||||
def store_base_vae(model):
|
||||
global base_vae, checkpoint_info
|
||||
if checkpoint_info != model.sd_checkpoint_info:
|
||||
assert not loaded_vae_file, "Trying to store non-base VAE!"
|
||||
base_vae = deepcopy(model.first_stage_model.state_dict())
|
||||
checkpoint_info = model.sd_checkpoint_info
|
||||
|
||||
|
||||
def delete_base_vae():
|
||||
global base_vae, checkpoint_info
|
||||
base_vae = None
|
||||
checkpoint_info = None
|
||||
|
||||
|
||||
def restore_base_vae(model):
|
||||
global loaded_vae_file
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
|
||||
print("Restoring base VAE")
|
||||
_load_vae_dict(model, base_vae)
|
||||
loaded_vae_file = None
|
||||
delete_base_vae()
|
||||
|
||||
|
||||
def get_filename(filepath):
|
||||
return os.path.splitext(os.path.basename(filepath))[0]
|
||||
|
||||
|
||||
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
|
||||
global vae_dict, vae_list
|
||||
res = {}
|
||||
candidates = [
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
|
||||
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
|
||||
*glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
|
||||
]
|
||||
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
|
||||
candidates.append(shared.cmd_opts.vae_path)
|
||||
for filepath in candidates:
|
||||
name = get_filename(filepath)
|
||||
res[name] = filepath
|
||||
vae_list.clear()
|
||||
vae_list.extend(default_vae_list)
|
||||
vae_list.extend(list(res.keys()))
|
||||
vae_dict.clear()
|
||||
vae_dict.update(res)
|
||||
vae_dict.update(default_vae_dict)
|
||||
return vae_list
|
||||
|
||||
|
||||
def get_vae_from_settings(vae_file="auto"):
|
||||
# else, we load from settings, if not set to be default
|
||||
if vae_file == "auto" and shared.opts.sd_vae is not None:
|
||||
# if saved VAE settings isn't recognized, fallback to auto
|
||||
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
|
||||
# if VAE selected but not found, fallback to auto
|
||||
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
|
||||
vae_file = "auto"
|
||||
print(f"Selected VAE doesn't exist: {vae_file}")
|
||||
return vae_file
|
||||
|
||||
|
||||
def resolve_vae(checkpoint_file=None, vae_file="auto"):
|
||||
global first_load, vae_dict, vae_list
|
||||
|
||||
# if vae_file argument is provided, it takes priority, but not saved
|
||||
if vae_file and vae_file not in default_vae_list:
|
||||
if not os.path.isfile(vae_file):
|
||||
print(f"VAE provided as function argument doesn't exist: {vae_file}")
|
||||
vae_file = "auto"
|
||||
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
|
||||
if first_load and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
shared.opts.data['sd_vae'] = get_filename(vae_file)
|
||||
else:
|
||||
print(f"VAE provided as command line argument doesn't exist: {vae_file}")
|
||||
# fallback to selector in settings, if vae selector not set to act as default fallback
|
||||
if not shared.opts.sd_vae_as_default:
|
||||
vae_file = get_vae_from_settings(vae_file)
|
||||
# vae-path cmd arg takes priority for auto
|
||||
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
|
||||
if os.path.isfile(shared.cmd_opts.vae_path):
|
||||
vae_file = shared.cmd_opts.vae_path
|
||||
print(f"Using VAE provided as command line argument: {vae_file}")
|
||||
# if still not found, try look for ".vae.pt" beside model
|
||||
model_path = os.path.splitext(checkpoint_file)[0]
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.pt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# if still not found, try look for ".vae.ckpt" beside model
|
||||
if vae_file == "auto":
|
||||
vae_file_try = model_path + ".vae.ckpt"
|
||||
if os.path.isfile(vae_file_try):
|
||||
vae_file = vae_file_try
|
||||
print(f"Using VAE found similar to selected model: {vae_file}")
|
||||
# No more fallbacks for auto
|
||||
if vae_file == "auto":
|
||||
vae_file = None
|
||||
# Last check, just because
|
||||
if vae_file and not os.path.exists(vae_file):
|
||||
vae_file = None
|
||||
|
||||
return vae_file
|
||||
|
||||
|
||||
def load_vae(model, vae_file=None):
|
||||
global first_load, vae_dict, vae_list, loaded_vae_file
|
||||
# save_settings = False
|
||||
|
||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||
|
||||
if vae_file:
|
||||
if cache_enabled and vae_file in checkpoints_loaded:
|
||||
# use vae checkpoint cache
|
||||
print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
|
||||
store_base_vae(model)
|
||||
_load_vae_dict(model, checkpoints_loaded[vae_file])
|
||||
else:
|
||||
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
store_base_vae(model)
|
||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||
_load_vae_dict(model, vae_dict_1)
|
||||
|
||||
if cache_enabled:
|
||||
# cache newly loaded vae
|
||||
checkpoints_loaded[vae_file] = vae_dict_1.copy()
|
||||
|
||||
# clean up cache if limit is reached
|
||||
if cache_enabled:
|
||||
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
|
||||
# If vae used is not in dict, update it
|
||||
# It will be removed on refresh though
|
||||
vae_opt = get_filename(vae_file)
|
||||
if vae_opt not in vae_dict:
|
||||
vae_dict[vae_opt] = vae_file
|
||||
vae_list.append(vae_opt)
|
||||
elif loaded_vae_file:
|
||||
restore_base_vae(model)
|
||||
|
||||
loaded_vae_file = vae_file
|
||||
|
||||
first_load = False
|
||||
|
||||
|
||||
# don't call this from outside
|
||||
def _load_vae_dict(model, vae_dict_1):
|
||||
model.first_stage_model.load_state_dict(vae_dict_1)
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
def clear_loaded_vae():
|
||||
global loaded_vae_file
|
||||
loaded_vae_file = None
|
||||
|
||||
def reload_vae_weights(sd_model=None, vae_file="auto"):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
checkpoint_info = sd_model.sd_checkpoint_info
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
|
||||
|
||||
if loaded_vae_file == vae_file:
|
||||
return
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
load_vae(sd_model, vae_file)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
print("VAE Weights loaded.")
|
||||
return sd_model
|
58
modules/sd_vae_approx.py
Normal file
58
modules/sd_vae_approx.py
Normal file
@ -0,0 +1,58 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from modules import devices, paths
|
||||
|
||||
sd_vae_approx_model = None
|
||||
|
||||
|
||||
class VAEApprox(nn.Module):
|
||||
def __init__(self):
|
||||
super(VAEApprox, self).__init__()
|
||||
self.conv1 = nn.Conv2d(4, 8, (7, 7))
|
||||
self.conv2 = nn.Conv2d(8, 16, (5, 5))
|
||||
self.conv3 = nn.Conv2d(16, 32, (3, 3))
|
||||
self.conv4 = nn.Conv2d(32, 64, (3, 3))
|
||||
self.conv5 = nn.Conv2d(64, 32, (3, 3))
|
||||
self.conv6 = nn.Conv2d(32, 16, (3, 3))
|
||||
self.conv7 = nn.Conv2d(16, 8, (3, 3))
|
||||
self.conv8 = nn.Conv2d(8, 3, (3, 3))
|
||||
|
||||
def forward(self, x):
|
||||
extra = 11
|
||||
x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
|
||||
x = nn.functional.pad(x, (extra, extra, extra, extra))
|
||||
|
||||
for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
|
||||
x = layer(x)
|
||||
x = nn.functional.leaky_relu(x, 0.1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def model():
|
||||
global sd_vae_approx_model
|
||||
|
||||
if sd_vae_approx_model is None:
|
||||
sd_vae_approx_model = VAEApprox()
|
||||
sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
|
||||
sd_vae_approx_model.eval()
|
||||
sd_vae_approx_model.to(devices.device, devices.dtype)
|
||||
|
||||
return sd_vae_approx_model
|
||||
|
||||
|
||||
def cheap_approximation(sample):
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
|
||||
|
||||
coefs = torch.tensor([
|
||||
[0.298, 0.207, 0.208],
|
||||
[0.187, 0.286, 0.173],
|
||||
[-0.158, 0.189, 0.264],
|
||||
[-0.184, -0.271, -0.473],
|
||||
]).to(sample.device)
|
||||
|
||||
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
|
||||
|
||||
return x_sample
|
@ -3,24 +3,27 @@ import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from PIL import Image
|
||||
import gradio as gr
|
||||
import tqdm
|
||||
|
||||
import modules.artists
|
||||
import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.sd_models
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import sd_samplers, sd_models, localization
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules import localization, sd_vae, extensions, script_loading, errors
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
|
||||
|
||||
demo = None
|
||||
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
@ -39,34 +42,35 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
||||
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
|
||||
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
|
||||
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
@ -76,12 +80,27 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
|
||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
|
||||
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
|
||||
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
|
||||
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
|
||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
|
||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
restricted_opts = [
|
||||
|
||||
restricted_opts = {
|
||||
"samples_filename_pattern",
|
||||
"directories_filename_pattern",
|
||||
"outdir_samples",
|
||||
"outdir_txt2img_samples",
|
||||
"outdir_img2img_samples",
|
||||
@ -89,10 +108,23 @@ restricted_opts = [
|
||||
"outdir_grids",
|
||||
"outdir_txt2img_grids",
|
||||
"outdir_save",
|
||||
}
|
||||
|
||||
ui_reorder_categories = [
|
||||
"sampler",
|
||||
"dimensions",
|
||||
"cfg",
|
||||
"seed",
|
||||
"checkboxes",
|
||||
"hires_fix",
|
||||
"batch",
|
||||
"scripts",
|
||||
]
|
||||
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
|
||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
||||
|
||||
device = devices.device
|
||||
weight_load_location = None if cmd_opts.lowram else "cpu"
|
||||
@ -103,10 +135,12 @@ xformers_available = False
|
||||
config_filename = cmd_opts.ui_settings_file
|
||||
|
||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
hypernetworks = {}
|
||||
loaded_hypernetwork = None
|
||||
|
||||
|
||||
def reload_hypernetworks():
|
||||
from modules.hypernetworks import hypernetwork
|
||||
global hypernetworks
|
||||
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
@ -126,6 +160,8 @@ class State:
|
||||
current_image = None
|
||||
current_image_sampling_step = 0
|
||||
textinfo = None
|
||||
time_start = None
|
||||
need_restart = False
|
||||
|
||||
def skip(self):
|
||||
self.skipped = True
|
||||
@ -134,12 +170,67 @@ class State:
|
||||
self.interrupted = True
|
||||
|
||||
def nextjob(self):
|
||||
if opts.show_progress_every_n_steps == -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
self.job_no += 1
|
||||
self.sampling_step = 0
|
||||
self.current_image_sampling_step = 0
|
||||
|
||||
def get_job_timestamp(self):
|
||||
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
|
||||
def dict(self):
|
||||
obj = {
|
||||
"skipped": self.skipped,
|
||||
"interrupted": self.interrupted,
|
||||
"job": self.job,
|
||||
"job_count": self.job_count,
|
||||
"job_timestamp": self.job_timestamp,
|
||||
"job_no": self.job_no,
|
||||
"sampling_step": self.sampling_step,
|
||||
"sampling_steps": self.sampling_steps,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
||||
def begin(self):
|
||||
self.sampling_step = 0
|
||||
self.job_count = -1
|
||||
self.job_no = 0
|
||||
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
self.current_latent = None
|
||||
self.current_image = None
|
||||
self.current_image_sampling_step = 0
|
||||
self.skipped = False
|
||||
self.interrupted = False
|
||||
self.textinfo = None
|
||||
self.time_start = time.time()
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
def end(self):
|
||||
self.job = ""
|
||||
self.job_count = 0
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
||||
def set_current_image(self):
|
||||
if not parallel_processing_allowed:
|
||||
return
|
||||
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
|
||||
self.do_set_current_image()
|
||||
|
||||
def do_set_current_image(self):
|
||||
if self.current_latent is None:
|
||||
return
|
||||
|
||||
import modules.sd_samplers
|
||||
if opts.show_progress_grid:
|
||||
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
|
||||
else:
|
||||
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
|
||||
state = State()
|
||||
@ -153,8 +244,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||
|
||||
face_restorers = []
|
||||
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
@ -162,13 +251,13 @@ def realesrgan_models_names():
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
|
||||
self.default = default
|
||||
self.label = label
|
||||
self.component = component
|
||||
self.component_args = component_args
|
||||
self.onchange = onchange
|
||||
self.section = None
|
||||
self.section = section
|
||||
self.refresh = refresh
|
||||
|
||||
|
||||
@ -179,6 +268,21 @@ def options_section(section_identifier, options_dict):
|
||||
return options_dict
|
||||
|
||||
|
||||
def list_checkpoint_tiles():
|
||||
import modules.sd_models
|
||||
return modules.sd_models.checkpoint_tiles()
|
||||
|
||||
|
||||
def refresh_checkpoints():
|
||||
import modules.sd_models
|
||||
return modules.sd_models.list_models()
|
||||
|
||||
|
||||
def list_samplers():
|
||||
import modules.sd_samplers
|
||||
return modules.sd_samplers.all_samplers
|
||||
|
||||
|
||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||
|
||||
options_templates = {}
|
||||
@ -186,7 +290,8 @@ options_templates = {}
|
||||
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
||||
"samples_save": OptionInfo(True, "Always save all generated images"),
|
||||
"samples_format": OptionInfo('png', 'File format for images'),
|
||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern"),
|
||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
|
||||
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
||||
|
||||
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
||||
"grid_format": OptionInfo('png', 'File format for grids'),
|
||||
@ -198,12 +303,19 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
|
||||
|
||||
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
|
||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
|
||||
|
||||
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||
@ -221,19 +333,15 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
||||
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
|
||||
"grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
|
||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||
"directories_filename_pattern": OptionInfo("", "Directory name pattern"),
|
||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
|
||||
"directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs),
|
||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
|
||||
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||
@ -249,31 +357,42 @@ options_templates.update(options_section(('system', "System"), {
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
||||
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
|
||||
@ -286,26 +405,34 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
|
||||
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
|
||||
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
|
||||
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
|
||||
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
|
||||
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||
"show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
|
||||
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"),
|
||||
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
|
||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||
@ -315,6 +442,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section((None, "Hidden options"), {
|
||||
"disabled_extensions": OptionInfo([], "Disable those extensions"),
|
||||
}))
|
||||
|
||||
options_templates.update()
|
||||
|
||||
|
||||
class Options:
|
||||
data = None
|
||||
@ -326,8 +459,19 @@ class Options:
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if self.data is not None:
|
||||
if key in self.data:
|
||||
if key in self.data or key in self.data_labels:
|
||||
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
||||
|
||||
info = opts.data_labels.get(key, None)
|
||||
comp_args = info.component_args if info else None
|
||||
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
|
||||
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
|
||||
self.data[key] = value
|
||||
return
|
||||
|
||||
return super(Options, self).__setattr__(key, value)
|
||||
|
||||
@ -341,9 +485,33 @@ class Options:
|
||||
|
||||
return super(Options, self).__getattribute__(item)
|
||||
|
||||
def set(self, key, value):
|
||||
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
|
||||
|
||||
oldval = self.data.get(key, None)
|
||||
if oldval == value:
|
||||
return False
|
||||
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
if self.data_labels[key].onchange is not None:
|
||||
try:
|
||||
self.data_labels[key].onchange()
|
||||
except Exception as e:
|
||||
errors.display(e, f"changing setting {key} to {value}")
|
||||
setattr(self, key, oldval)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def save(self, filename):
|
||||
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
||||
|
||||
with open(filename, "w", encoding="utf8") as file:
|
||||
json.dump(self.data, file)
|
||||
json.dump(self.data, file, indent=4)
|
||||
|
||||
def same_type(self, x, y):
|
||||
if x is None or y is None:
|
||||
@ -368,25 +536,51 @@ class Options:
|
||||
if bad_settings > 0:
|
||||
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
|
||||
|
||||
def onchange(self, key, func):
|
||||
def onchange(self, key, func, call=True):
|
||||
item = self.data_labels.get(key)
|
||||
item.onchange = func
|
||||
|
||||
if call:
|
||||
func()
|
||||
|
||||
def dumpjson(self):
|
||||
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
|
||||
return json.dumps(d)
|
||||
|
||||
def add_option(self, key, info):
|
||||
self.data_labels[key] = info
|
||||
|
||||
def reorder(self):
|
||||
"""reorder settings so that all items related to section always go together"""
|
||||
|
||||
section_ids = {}
|
||||
settings_items = self.data_labels.items()
|
||||
for k, item in settings_items:
|
||||
if item.section not in section_ids:
|
||||
section_ids[item.section] = len(section_ids)
|
||||
|
||||
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
|
||||
|
||||
|
||||
opts = Options()
|
||||
if os.path.exists(config_filename):
|
||||
opts.load(config_filename)
|
||||
|
||||
latent_upscale_default_mode = "Latent"
|
||||
latent_upscale_modes = {
|
||||
"Latent": {"mode": "bilinear", "antialias": False},
|
||||
"Latent (antialiased)": {"mode": "bilinear", "antialias": True},
|
||||
"Latent (bicubic)": {"mode": "bicubic", "antialias": False},
|
||||
"Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
|
||||
"Latent (nearest)": {"mode": "nearest", "antialias": False},
|
||||
}
|
||||
|
||||
sd_upscalers = []
|
||||
|
||||
sd_model = None
|
||||
|
||||
clip_model = None
|
||||
|
||||
progress_print_out = sys.stdout
|
||||
|
||||
|
||||
@ -426,3 +620,8 @@ total_tqdm = TotalTQDM()
|
||||
|
||||
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
|
||||
mem_mon.start()
|
||||
|
||||
|
||||
def listfiles(dirname):
|
||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
|
||||
return [file for file in filenames if os.path.isfile(file)]
|
||||
|
@ -65,17 +65,6 @@ class StyleDatabase:
|
||||
def apply_negative_styles_to_prompt(self, prompt, styles):
|
||||
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
|
||||
|
||||
def apply_styles(self, p: StableDiffusionProcessing) -> None:
|
||||
if isinstance(p.prompt, list):
|
||||
p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
|
||||
else:
|
||||
p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
|
||||
|
||||
if isinstance(p.negative_prompt, list):
|
||||
p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
|
||||
else:
|
||||
p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
|
||||
|
||||
def save_styles(self, path: str) -> None:
|
||||
# Write to temporary file first, so we don't nuke the file if something goes wrong
|
||||
fd, temp_path = tempfile.mkstemp(".csv")
|
||||
|
341
modules/textual_inversion/autocrop.py
Normal file
341
modules/textual_inversion/autocrop.py
Normal file
@ -0,0 +1,341 @@
|
||||
import cv2
|
||||
import requests
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from math import log, sqrt
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
GREEN = "#0F0"
|
||||
BLUE = "#00F"
|
||||
RED = "#F00"
|
||||
|
||||
|
||||
def crop_image(im, settings):
|
||||
""" Intelligently crop an image to the subject matter """
|
||||
|
||||
scale_by = 1
|
||||
if is_landscape(im.width, im.height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
elif is_portrait(im.width, im.height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_square(im.width, im.height):
|
||||
if is_square(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_landscape(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_width / im.width
|
||||
elif is_portrait(settings.crop_width, settings.crop_height):
|
||||
scale_by = settings.crop_height / im.height
|
||||
|
||||
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
|
||||
im_debug = im.copy()
|
||||
|
||||
focus = focal_point(im_debug, settings)
|
||||
|
||||
# take the focal point and turn it into crop coordinates that try to center over the focal
|
||||
# point but then get adjusted back into the frame
|
||||
y_half = int(settings.crop_height / 2)
|
||||
x_half = int(settings.crop_width / 2)
|
||||
|
||||
x1 = focus.x - x_half
|
||||
if x1 < 0:
|
||||
x1 = 0
|
||||
elif x1 + settings.crop_width > im.width:
|
||||
x1 = im.width - settings.crop_width
|
||||
|
||||
y1 = focus.y - y_half
|
||||
if y1 < 0:
|
||||
y1 = 0
|
||||
elif y1 + settings.crop_height > im.height:
|
||||
y1 = im.height - settings.crop_height
|
||||
|
||||
x2 = x1 + settings.crop_width
|
||||
y2 = y1 + settings.crop_height
|
||||
|
||||
crop = [x1, y1, x2, y2]
|
||||
|
||||
results = []
|
||||
|
||||
results.append(im.crop(tuple(crop)))
|
||||
|
||||
if settings.annotate_image:
|
||||
d = ImageDraw.Draw(im_debug)
|
||||
rect = list(crop)
|
||||
rect[2] -= 1
|
||||
rect[3] -= 1
|
||||
d.rectangle(rect, outline=GREEN)
|
||||
results.append(im_debug)
|
||||
if settings.destop_view_image:
|
||||
im_debug.show()
|
||||
|
||||
return results
|
||||
|
||||
def focal_point(im, settings):
|
||||
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
|
||||
entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
|
||||
face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []
|
||||
|
||||
pois = []
|
||||
|
||||
weight_pref_total = 0
|
||||
if len(corner_points) > 0:
|
||||
weight_pref_total += settings.corner_points_weight
|
||||
if len(entropy_points) > 0:
|
||||
weight_pref_total += settings.entropy_points_weight
|
||||
if len(face_points) > 0:
|
||||
weight_pref_total += settings.face_points_weight
|
||||
|
||||
corner_centroid = None
|
||||
if len(corner_points) > 0:
|
||||
corner_centroid = centroid(corner_points)
|
||||
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
|
||||
pois.append(corner_centroid)
|
||||
|
||||
entropy_centroid = None
|
||||
if len(entropy_points) > 0:
|
||||
entropy_centroid = centroid(entropy_points)
|
||||
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
|
||||
pois.append(entropy_centroid)
|
||||
|
||||
face_centroid = None
|
||||
if len(face_points) > 0:
|
||||
face_centroid = centroid(face_points)
|
||||
face_centroid.weight = settings.face_points_weight / weight_pref_total
|
||||
pois.append(face_centroid)
|
||||
|
||||
average_point = poi_average(pois, settings)
|
||||
|
||||
if settings.annotate_image:
|
||||
d = ImageDraw.Draw(im)
|
||||
max_size = min(im.width, im.height) * 0.07
|
||||
if corner_centroid is not None:
|
||||
color = BLUE
|
||||
box = corner_centroid.bounding(max_size * corner_centroid.weight)
|
||||
d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(corner_points) > 1:
|
||||
for f in corner_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
if entropy_centroid is not None:
|
||||
color = "#ff0"
|
||||
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
|
||||
d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(entropy_points) > 1:
|
||||
for f in entropy_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
if face_centroid is not None:
|
||||
color = RED
|
||||
box = face_centroid.bounding(max_size * face_centroid.weight)
|
||||
d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
|
||||
d.ellipse(box, outline=color)
|
||||
if len(face_points) > 1:
|
||||
for f in face_points:
|
||||
d.rectangle(f.bounding(4), outline=color)
|
||||
|
||||
d.ellipse(average_point.bounding(max_size), outline=GREEN)
|
||||
|
||||
return average_point
|
||||
|
||||
|
||||
def image_face_points(im, settings):
|
||||
if settings.dnn_model_path is not None:
|
||||
detector = cv2.FaceDetectorYN.create(
|
||||
settings.dnn_model_path,
|
||||
"",
|
||||
(im.width, im.height),
|
||||
0.9, # score threshold
|
||||
0.3, # nms threshold
|
||||
5000 # keep top k before nms
|
||||
)
|
||||
faces = detector.detect(np.array(im))
|
||||
results = []
|
||||
if faces[1] is not None:
|
||||
for face in faces[1]:
|
||||
x = face[0]
|
||||
y = face[1]
|
||||
w = face[2]
|
||||
h = face[3]
|
||||
results.append(
|
||||
PointOfInterest(
|
||||
int(x + (w * 0.5)), # face focus left/right is center
|
||||
int(y + (h * 0.33)), # face focus up/down is close to the top of the head
|
||||
size = w,
|
||||
weight = 1/len(faces[1])
|
||||
)
|
||||
)
|
||||
return results
|
||||
else:
|
||||
np_im = np.array(im)
|
||||
gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
tries = [
|
||||
[ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
|
||||
[ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
|
||||
]
|
||||
for t in tries:
|
||||
classifier = cv2.CascadeClassifier(t[0])
|
||||
minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
|
||||
try:
|
||||
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
||||
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
|
||||
except:
|
||||
continue
|
||||
|
||||
if len(faces) > 0:
|
||||
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
|
||||
return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
|
||||
return []
|
||||
|
||||
|
||||
def image_corner_points(im, settings):
|
||||
grayscale = im.convert("L")
|
||||
|
||||
# naive attempt at preventing focal points from collecting at watermarks near the bottom
|
||||
gd = ImageDraw.Draw(grayscale)
|
||||
gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
|
||||
|
||||
np_im = np.array(grayscale)
|
||||
|
||||
points = cv2.goodFeaturesToTrack(
|
||||
np_im,
|
||||
maxCorners=100,
|
||||
qualityLevel=0.04,
|
||||
minDistance=min(grayscale.width, grayscale.height)*0.06,
|
||||
useHarrisDetector=False,
|
||||
)
|
||||
|
||||
if points is None:
|
||||
return []
|
||||
|
||||
focal_points = []
|
||||
for point in points:
|
||||
x, y = point.ravel()
|
||||
focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
|
||||
|
||||
return focal_points
|
||||
|
||||
|
||||
def image_entropy_points(im, settings):
|
||||
landscape = im.height < im.width
|
||||
portrait = im.height > im.width
|
||||
if landscape:
|
||||
move_idx = [0, 2]
|
||||
move_max = im.size[0]
|
||||
elif portrait:
|
||||
move_idx = [1, 3]
|
||||
move_max = im.size[1]
|
||||
else:
|
||||
return []
|
||||
|
||||
e_max = 0
|
||||
crop_current = [0, 0, settings.crop_width, settings.crop_height]
|
||||
crop_best = crop_current
|
||||
while crop_current[move_idx[1]] < move_max:
|
||||
crop = im.crop(tuple(crop_current))
|
||||
e = image_entropy(crop)
|
||||
|
||||
if (e > e_max):
|
||||
e_max = e
|
||||
crop_best = list(crop_current)
|
||||
|
||||
crop_current[move_idx[0]] += 4
|
||||
crop_current[move_idx[1]] += 4
|
||||
|
||||
x_mid = int(crop_best[0] + settings.crop_width/2)
|
||||
y_mid = int(crop_best[1] + settings.crop_height/2)
|
||||
|
||||
return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
|
||||
|
||||
|
||||
def image_entropy(im):
|
||||
# greyscale image entropy
|
||||
# band = np.asarray(im.convert("L"))
|
||||
band = np.asarray(im.convert("1"), dtype=np.uint8)
|
||||
hist, _ = np.histogram(band, bins=range(0, 256))
|
||||
hist = hist[hist > 0]
|
||||
return -np.log2(hist / hist.sum()).sum()
|
||||
|
||||
def centroid(pois):
|
||||
x = [poi.x for poi in pois]
|
||||
y = [poi.y for poi in pois]
|
||||
return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
|
||||
|
||||
|
||||
def poi_average(pois, settings):
|
||||
weight = 0.0
|
||||
x = 0.0
|
||||
y = 0.0
|
||||
for poi in pois:
|
||||
weight += poi.weight
|
||||
x += poi.x * poi.weight
|
||||
y += poi.y * poi.weight
|
||||
avg_x = round(weight and x / weight)
|
||||
avg_y = round(weight and y / weight)
|
||||
|
||||
return PointOfInterest(avg_x, avg_y)
|
||||
|
||||
|
||||
def is_landscape(w, h):
|
||||
return w > h
|
||||
|
||||
|
||||
def is_portrait(w, h):
|
||||
return h > w
|
||||
|
||||
|
||||
def is_square(w, h):
|
||||
return w == h
|
||||
|
||||
|
||||
def download_and_cache_models(dirname):
|
||||
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
||||
model_file_name = 'face_detection_yunet.onnx'
|
||||
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
cache_file = os.path.join(dirname, model_file_name)
|
||||
if not os.path.exists(cache_file):
|
||||
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
|
||||
response = requests.get(download_url)
|
||||
with open(cache_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
if os.path.exists(cache_file):
|
||||
return cache_file
|
||||
return None
|
||||
|
||||
|
||||
class PointOfInterest:
|
||||
def __init__(self, x, y, weight=1.0, size=10):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.weight = weight
|
||||
self.size = size
|
||||
|
||||
def bounding(self, size):
|
||||
return [
|
||||
self.x - size//2,
|
||||
self.y - size//2,
|
||||
self.x + size//2,
|
||||
self.y + size//2
|
||||
]
|
||||
|
||||
|
||||
class Settings:
|
||||
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
|
||||
self.crop_width = crop_width
|
||||
self.crop_height = crop_height
|
||||
self.corner_points_weight = corner_points_weight
|
||||
self.entropy_points_weight = entropy_points_weight
|
||||
self.face_points_weight = face_points_weight
|
||||
self.annotate_image = annotate_image
|
||||
self.destop_view_image = False
|
||||
self.dnn_model_path = dnn_model_path
|
@ -3,7 +3,7 @@ import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
import random
|
||||
@ -11,25 +11,28 @@ import tqdm
|
||||
from modules import devices, shared
|
||||
import re
|
||||
|
||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
|
||||
|
||||
|
||||
class DatasetEntry:
|
||||
def __init__(self, filename=None, latent=None, filename_text=None):
|
||||
def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
|
||||
self.filename = filename
|
||||
self.latent = latent
|
||||
self.filename_text = filename_text
|
||||
self.cond = None
|
||||
self.cond_text = None
|
||||
self.latent_dist = latent_dist
|
||||
self.latent_sample = latent_sample
|
||||
self.cond = cond
|
||||
self.cond_text = cond_text
|
||||
self.pixel_values = pixel_values
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
@ -42,12 +45,19 @@ class PersonalizedBase(Dataset):
|
||||
self.lines = lines
|
||||
|
||||
assert data_root, 'dataset directory not specified'
|
||||
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
|
||||
assert os.listdir(data_root), "Dataset directory is empty"
|
||||
|
||||
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||
|
||||
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.tag_drop_out = tag_drop_out
|
||||
|
||||
print("Preparing dataset...")
|
||||
for path in tqdm.tqdm(self.image_paths):
|
||||
if shared.state.interrupted:
|
||||
raise Exception("interrupted")
|
||||
try:
|
||||
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||
except Exception:
|
||||
@ -69,53 +79,94 @@ class PersonalizedBase(Dataset):
|
||||
npimage = np.array(image).astype(np.uint8)
|
||||
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
|
||||
torchdata = torch.moveaxis(torchdata, 2, 0)
|
||||
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||
latent_sample = None
|
||||
|
||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||
init_latent = init_latent.to(devices.cpu)
|
||||
with devices.autocast():
|
||||
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
|
||||
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
latent_sampling_method = "once"
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||
elif latent_sampling_method == "deterministic":
|
||||
# Works only for DiagonalGaussianDistribution
|
||||
latent_dist.std = 0
|
||||
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
|
||||
elif latent_sampling_method == "random":
|
||||
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
|
||||
|
||||
if include_cond:
|
||||
if not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
entry.cond_text = self.create_text(filename_text)
|
||||
|
||||
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||
with devices.autocast():
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
|
||||
self.dataset.append(entry)
|
||||
del torchdata
|
||||
del latent_dist
|
||||
del latent_sample
|
||||
|
||||
assert len(self.dataset) > 1, "No images have been found in the dataset."
|
||||
self.length = len(self.dataset) * repeats // batch_size
|
||||
|
||||
self.initial_indexes = np.arange(len(self.dataset))
|
||||
self.indexes = None
|
||||
self.shuffle()
|
||||
|
||||
def shuffle(self):
|
||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||
self.length = len(self.dataset)
|
||||
assert self.length > 0, "No images have been found in the dataset."
|
||||
self.batch_size = min(batch_size, self.length)
|
||||
self.gradient_step = min(gradient_step, self.length // self.batch_size)
|
||||
self.latent_sampling_method = latent_sampling_method
|
||||
|
||||
def create_text(self, filename_text):
|
||||
text = random.choice(self.lines)
|
||||
tags = filename_text.split(',')
|
||||
if self.tag_drop_out != 0:
|
||||
tags = [t for t in tags if random.random() > self.tag_drop_out]
|
||||
if self.shuffle_tags:
|
||||
random.shuffle(tags)
|
||||
text = text.replace("[filewords]", ','.join(tags))
|
||||
text = text.replace("[name]", self.placeholder_token)
|
||||
text = text.replace("[filewords]", filename_text)
|
||||
return text
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, i):
|
||||
res = []
|
||||
|
||||
for j in range(self.batch_size):
|
||||
position = i * self.batch_size + j
|
||||
if position % len(self.indexes) == 0:
|
||||
self.shuffle()
|
||||
|
||||
index = self.indexes[position % len(self.indexes)]
|
||||
entry = self.dataset[index]
|
||||
|
||||
if entry.cond is None:
|
||||
entry = self.dataset[i]
|
||||
if self.tag_drop_out != 0 or self.shuffle_tags:
|
||||
entry.cond_text = self.create_text(entry.filename_text)
|
||||
if self.latent_sampling_method == "random":
|
||||
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
|
||||
return entry
|
||||
|
||||
res.append(entry)
|
||||
class PersonalizedDataLoader(DataLoader):
|
||||
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
|
||||
super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
|
||||
if latent_sampling_method == "random":
|
||||
self.collate_fn = collate_wrapper_random
|
||||
else:
|
||||
self.collate_fn = collate_wrapper
|
||||
|
||||
return res
|
||||
|
||||
class BatchLoader:
|
||||
def __init__(self, data):
|
||||
self.cond_text = [entry.cond_text for entry in data]
|
||||
self.cond = [entry.cond for entry in data]
|
||||
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
|
||||
#self.emb_index = [entry.emb_index for entry in data]
|
||||
#print(self.latent_sample.device)
|
||||
|
||||
def pin_memory(self):
|
||||
self.latent_sample = self.latent_sample.pin_memory()
|
||||
return self
|
||||
|
||||
def collate_wrapper(batch):
|
||||
return BatchLoader(batch)
|
||||
|
||||
class BatchLoaderRandom(BatchLoader):
|
||||
def __init__(self, data):
|
||||
super().__init__(data)
|
||||
|
||||
def pin_memory(self):
|
||||
return self
|
||||
|
||||
def collate_wrapper_random(batch):
|
||||
return BatchLoaderRandom(batch)
|
@ -5,6 +5,7 @@ import zlib
|
||||
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
|
||||
from fonts.ttf import Roboto
|
||||
import torch
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
class EmbeddingEncoder(json.JSONEncoder):
|
||||
@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
|
||||
from math import cos
|
||||
|
||||
image = srcimage.copy()
|
||||
|
||||
fontsize = 32
|
||||
if textfont is None:
|
||||
try:
|
||||
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
|
||||
@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
|
||||
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
fontsize = 32
|
||||
|
||||
font = ImageFont.truetype(textfont, fontsize)
|
||||
padding = 10
|
||||
|
||||
|
@ -4,14 +4,17 @@ import tqdm
|
||||
class LearnScheduleIterator:
|
||||
def __init__(self, learn_rate, max_steps, cur_step=0):
|
||||
"""
|
||||
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
|
||||
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
|
||||
"""
|
||||
|
||||
pairs = learn_rate.split(',')
|
||||
self.rates = []
|
||||
self.it = 0
|
||||
self.maxit = 0
|
||||
try:
|
||||
for i, pair in enumerate(pairs):
|
||||
if not pair.strip():
|
||||
continue
|
||||
tmp = pair.split(':')
|
||||
if len(tmp) == 2:
|
||||
step = int(tmp[1])
|
||||
@ -28,6 +31,10 @@ class LearnScheduleIterator:
|
||||
self.rates.append((float(tmp[0]), max_steps))
|
||||
self.maxit += 1
|
||||
return
|
||||
assert self.rates
|
||||
except (ValueError, AssertionError):
|
||||
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
@ -52,7 +59,7 @@ class LearnRateScheduler:
|
||||
self.finished = False
|
||||
|
||||
def apply(self, optimizer, step_number):
|
||||
if step_number <= self.end_step:
|
||||
if step_number < self.end_step:
|
||||
return
|
||||
|
||||
try:
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user