Merge branch 'dev' into dora-weight-decompose

This commit is contained in:
AUTOMATIC1111 2024-03-16 20:20:02 +03:00 committed by GitHub
commit 8dcb8faf5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
41 changed files with 761 additions and 327 deletions

View File

@ -11,8 +11,8 @@ jobs:
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name
steps: steps:
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v3 uses: actions/checkout@v4
- uses: actions/setup-python@v4 - uses: actions/setup-python@v5
with: with:
python-version: 3.11 python-version: 3.11
# NB: there's no cache: pip here since we're not installing anything # NB: there's no cache: pip here since we're not installing anything
@ -29,9 +29,9 @@ jobs:
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name
steps: steps:
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install Node.js - name: Install Node.js
uses: actions/setup-node@v3 uses: actions/setup-node@v4
with: with:
node-version: 18 node-version: 18
- run: npm i --ci - run: npm i --ci

View File

@ -11,9 +11,9 @@ jobs:
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name
steps: steps:
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Set up Python 3.10 - name: Set up Python 3.10
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: 3.10.6 python-version: 3.10.6
cache: pip cache: pip
@ -22,7 +22,7 @@ jobs:
launch.py launch.py
- name: Cache models - name: Cache models
id: cache-models id: cache-models
uses: actions/cache@v3 uses: actions/cache@v4
with: with:
path: models path: models
key: "2023-12-30" key: "2023-12-30"
@ -68,13 +68,13 @@ jobs:
python -m coverage report -i python -m coverage report -i
python -m coverage html -i python -m coverage html -i
- name: Upload main app output - name: Upload main app output
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
if: always() if: always()
with: with:
name: output name: output
path: output.txt path: output.txt
- name: Upload coverage HTML - name: Upload coverage HTML
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
if: always() if: always()
with: with:
name: htmlcov name: htmlcov

View File

@ -98,6 +98,7 @@ Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-di
- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) - [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended)
- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. - [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page) - [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page)
- [Ascend NPUs](https://github.com/wangshuai09/stable-diffusion-webui/wiki/Install-and-run-on-Ascend-NPUs) (external wiki page)
Alternatively, use online services (like Google Colab): Alternatively, use online services (like Google Colab):

View File

@ -36,13 +36,6 @@ class NetworkModuleOFT(network.NetworkModule):
# self.alpha is unused # self.alpha is unused
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
# LyCORIS BOFT
if self.oft_blocks.dim() == 4:
self.is_boft = True
self.rescale = weights.w.get('rescale', None)
if self.rescale is not None:
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d] is_conv = type(self.sd_module) in [torch.nn.Conv2d]
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
@ -54,6 +47,13 @@ class NetworkModuleOFT(network.NetworkModule):
elif is_other_linear: elif is_other_linear:
self.out_dim = self.sd_module.embed_dim self.out_dim = self.sd_module.embed_dim
# LyCORIS BOFT
if self.oft_blocks.dim() == 4:
self.is_boft = True
self.rescale = weights.w.get('rescale', None)
if self.rescale is not None and not is_other_linear:
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
self.num_blocks = self.dim self.num_blocks = self.dim
self.block_size = self.out_dim // self.dim self.block_size = self.out_dim // self.dim
self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim

View File

@ -252,6 +252,7 @@ onUiLoaded(async() => {
let isMoving = false; let isMoving = false;
let mouseX, mouseY; let mouseX, mouseY;
let activeElement; let activeElement;
let interactedWithAltKey = false;
const elements = Object.fromEntries( const elements = Object.fromEntries(
Object.keys(elementIDs).map(id => [ Object.keys(elementIDs).map(id => [
@ -508,6 +509,10 @@ onUiLoaded(async() => {
if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) { if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) {
e.preventDefault(); e.preventDefault();
if(hotkeysConfig.canvas_hotkey_zoom === "Alt"){
interactedWithAltKey = true;
}
let zoomPosX, zoomPosY; let zoomPosX, zoomPosY;
let delta = 0.2; let delta = 0.2;
if (elemData[elemId].zoomLevel > 7) { if (elemData[elemId].zoomLevel > 7) {
@ -793,13 +798,17 @@ onUiLoaded(async() => {
targetElement.addEventListener("wheel", e => { targetElement.addEventListener("wheel", e => {
// change zoom level // change zoom level
const operation = e.deltaY > 0 ? "-" : "+"; const operation = (e.deltaY || -e.wheelDelta) > 0 ? "-" : "+";
changeZoomLevel(operation, e); changeZoomLevel(operation, e);
// Handle brush size adjustment with ctrl key pressed // Handle brush size adjustment with ctrl key pressed
if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) { if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) {
e.preventDefault(); e.preventDefault();
if(hotkeysConfig.canvas_hotkey_adjust === "Alt"){
interactedWithAltKey = true;
}
// Increase or decrease brush size based on scroll direction // Increase or decrease brush size based on scroll direction
adjustBrushSize(elemId, e.deltaY); adjustBrushSize(elemId, e.deltaY);
} }
@ -839,6 +848,20 @@ onUiLoaded(async() => {
document.addEventListener("keydown", handleMoveKeyDown); document.addEventListener("keydown", handleMoveKeyDown);
document.addEventListener("keyup", handleMoveKeyUp); document.addEventListener("keyup", handleMoveKeyUp);
// Prevent firefox from opening main menu when alt is used as a hotkey for zoom or brush size
function handleAltKeyUp(e) {
if (e.key !== "Alt" || !interactedWithAltKey) {
return;
}
e.preventDefault();
interactedWithAltKey = false;
}
document.addEventListener("keyup", handleAltKeyUp);
// Detect zoom level and update the pan speed. // Detect zoom level and update the pan speed.
function updatePanPosition(movementX, movementY) { function updatePanPosition(movementX, movementY) {
let panSpeed = 2; let panSpeed = 2;

View File

@ -0,0 +1,8 @@
<div class="extra-network-pane-content-dirs">
<div id='{tabname}_{extra_networks_tabname}_dirs' class='extra-network-dirs'>
{dirs_html}
</div>
<div id='{tabname}_{extra_networks_tabname}_cards' class='extra-network-cards'>
{items_html}
</div>
</div>

View File

@ -0,0 +1,8 @@
<div class="extra-network-pane-content-tree resize-handle-row">
<div id='{tabname}_{extra_networks_tabname}_tree' class='extra-network-tree' style='flex-basis: {extra_networks_tree_view_default_width}px'>
{tree_html}
</div>
<div id='{tabname}_{extra_networks_tabname}_cards' class='extra-network-cards' style='flex-grow: 1;'>
{items_html}
</div>
</div>

View File

@ -1,23 +1,53 @@
<div id='{tabname}_{extra_networks_tabname}_pane' class='extra-network-pane'> <div id='{tabname}_{extra_networks_tabname}_pane' class='extra-network-pane {tree_view_div_default_display_class}'>
<div class="extra-network-control" id="{tabname}_{extra_networks_tabname}_controls" style="display:none" > <div class="extra-network-control" id="{tabname}_{extra_networks_tabname}_controls" style="display:none" >
<div class="extra-network-control--search"> <div class="extra-network-control--search">
<input <input
id="{tabname}_{extra_networks_tabname}_extra_search" id="{tabname}_{extra_networks_tabname}_extra_search"
class="extra-network-control--search-text" class="extra-network-control--search-text"
type="search" type="search"
placeholder="Filter files" placeholder="Search"
> >
</div> </div>
<small>Sort: </small>
<div <div
id="{tabname}_{extra_networks_tabname}_extra_sort" id="{tabname}_{extra_networks_tabname}_extra_sort_path"
class="extra-network-control--sort" class="extra-network-control--sort{sort_path_active}"
data-sortmode="{data_sortmode}" data-sortkey="default"
data-sortkey="{data_sortkey}"
title="Sort by path" title="Sort by path"
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');" onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
> >
<i class="extra-network-control--sort-icon"></i> <i class="extra-network-control--icon extra-network-control--sort-icon"></i>
</div> </div>
<div
id="{tabname}_{extra_networks_tabname}_extra_sort_name"
class="extra-network-control--sort{sort_name_active}"
data-sortkey="name"
title="Sort by name"
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--icon extra-network-control--sort-icon"></i>
</div>
<div
id="{tabname}_{extra_networks_tabname}_extra_sort_date_created"
class="extra-network-control--sort{sort_date_created_active}"
data-sortkey="date_created"
title="Sort by date created"
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--icon extra-network-control--sort-icon"></i>
</div>
<div
id="{tabname}_{extra_networks_tabname}_extra_sort_date_modified"
class="extra-network-control--sort{sort_date_modified_active}"
data-sortkey="date_modified"
title="Sort by date modified"
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
>
<i class="extra-network-control--icon extra-network-control--sort-icon"></i>
</div>
<small> </small>
<div <div
id="{tabname}_{extra_networks_tabname}_extra_sort_dir" id="{tabname}_{extra_networks_tabname}_extra_sort_dir"
class="extra-network-control--sort-dir" class="extra-network-control--sort-dir"
@ -25,15 +55,18 @@
title="Sort ascending" title="Sort ascending"
onclick="extraNetworksControlSortDirOnClick(event, '{tabname}', '{extra_networks_tabname}');" onclick="extraNetworksControlSortDirOnClick(event, '{tabname}', '{extra_networks_tabname}');"
> >
<i class="extra-network-control--sort-dir-icon"></i> <i class="extra-network-control--icon extra-network-control--sort-dir-icon"></i>
</div> </div>
<small> </small>
<div <div
id="{tabname}_{extra_networks_tabname}_extra_tree_view" id="{tabname}_{extra_networks_tabname}_extra_tree_view"
class="extra-network-control--tree-view {tree_view_btn_extra_class}" class="extra-network-control--tree-view {tree_view_btn_extra_class}"
title="Enable Tree View" title="Enable Tree View"
onclick="extraNetworksControlTreeViewOnClick(event, '{tabname}', '{extra_networks_tabname}');" onclick="extraNetworksControlTreeViewOnClick(event, '{tabname}', '{extra_networks_tabname}');"
> >
<i class="extra-network-control--tree-view-icon"></i> <i class="extra-network-control--icon extra-network-control--tree-view-icon"></i>
</div> </div>
<div <div
id="{tabname}_{extra_networks_tabname}_extra_refresh" id="{tabname}_{extra_networks_tabname}_extra_refresh"
@ -41,15 +74,8 @@
title="Refresh page" title="Refresh page"
onclick="extraNetworksControlRefreshOnClick(event, '{tabname}', '{extra_networks_tabname}');" onclick="extraNetworksControlRefreshOnClick(event, '{tabname}', '{extra_networks_tabname}');"
> >
<i class="extra-network-control--refresh-icon"></i> <i class="extra-network-control--icon extra-network-control--refresh-icon"></i>
</div>
</div>
<div class="extra-network-pane-content resize-handle-row" style="display: {extra_network_pane_content_default_display};">
<div id='{tabname}_{extra_networks_tabname}_tree' class='extra-network-tree {tree_view_div_extra_class}' style='flex-basis: {extra_networks_tree_view_default_width}px; display: {tree_view_div_default_display};'>
{tree_html}
</div>
<div id='{tabname}_{extra_networks_tabname}_cards' class='extra-network-cards' style='flex-grow: 1;'>
{items_html}
</div> </div>
</div> </div>
{pane_content}
</div> </div>

View File

@ -64,6 +64,14 @@ function keyupEditAttention(event) {
selectionEnd++; selectionEnd++;
} }
// deselect surrounding whitespace
while (text[selectionStart] == " " && selectionStart < selectionEnd) {
selectionStart++;
}
while (text[selectionEnd - 1] == " " && selectionEnd > selectionStart) {
selectionEnd--;
}
target.setSelectionRange(selectionStart, selectionEnd); target.setSelectionRange(selectionStart, selectionEnd);
return true; return true;
} }

View File

@ -39,12 +39,12 @@ function setupExtraNetworksForTab(tabname) {
// tabname_full = {tabname}_{extra_networks_tabname} // tabname_full = {tabname}_{extra_networks_tabname}
var tabname_full = elem.id; var tabname_full = elem.id;
var search = gradioApp().querySelector("#" + tabname_full + "_extra_search"); var search = gradioApp().querySelector("#" + tabname_full + "_extra_search");
var sort_mode = gradioApp().querySelector("#" + tabname_full + "_extra_sort");
var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir"); var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir");
var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh"); var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh");
var currentSort = '';
// If any of the buttons above don't exist, we want to skip this iteration of the loop. // If any of the buttons above don't exist, we want to skip this iteration of the loop.
if (!search || !sort_mode || !sort_dir || !refresh) { if (!search || !sort_dir || !refresh) {
return; // `return` is equivalent of `continue` but for forEach loops. return; // `return` is equivalent of `continue` but for forEach loops.
} }
@ -52,7 +52,7 @@ function setupExtraNetworksForTab(tabname) {
var searchTerm = search.value.toLowerCase(); var searchTerm = search.value.toLowerCase();
gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
var searchOnly = elem.querySelector('.search_only'); var searchOnly = elem.querySelector('.search_only');
var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms, .description'), function(t) {
return t.textContent.toLowerCase(); return t.textContent.toLowerCase();
}).join(" "); }).join(" ");
@ -71,42 +71,46 @@ function setupExtraNetworksForTab(tabname) {
}; };
var applySort = function(force) { var applySort = function(force) {
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); var cards = gradioApp().querySelectorAll('#' + tabname_full + ' div.card');
var parent = gradioApp().querySelector('#' + tabname_full + "_cards");
var reverse = sort_dir.dataset.sortdir == "Descending"; var reverse = sort_dir.dataset.sortdir == "Descending";
var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; var activeSearchElem = gradioApp().querySelector('#' + tabname_full + "_controls .extra-network-control--sort.extra-network-control--enabled");
sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); var sortKey = activeSearchElem ? activeSearchElem.dataset.sortkey : "default";
var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; var sortKeyDataField = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
var sortKeyStore = sortKey + "-" + sort_dir.dataset.sortdir + "-" + cards.length;
if (sortKeyStore == sort_mode.dataset.sortkey && !force) { if (sortKeyStore == currentSort && !force) {
return; return;
} }
sort_mode.dataset.sortkey = sortKeyStore; currentSort = sortKeyStore;
cards.forEach(function(card) {
card.originalParentElement = card.parentElement;
});
var sortedCards = Array.from(cards); var sortedCards = Array.from(cards);
sortedCards.sort(function(cardA, cardB) { sortedCards.sort(function(cardA, cardB) {
var a = cardA.dataset[sortKey]; var a = cardA.dataset[sortKeyDataField];
var b = cardB.dataset[sortKey]; var b = cardB.dataset[sortKeyDataField];
if (!isNaN(a) && !isNaN(b)) { if (!isNaN(a) && !isNaN(b)) {
return parseInt(a) - parseInt(b); return parseInt(a) - parseInt(b);
} }
return (a < b ? -1 : (a > b ? 1 : 0)); return (a < b ? -1 : (a > b ? 1 : 0));
}); });
if (reverse) { if (reverse) {
sortedCards.reverse(); sortedCards.reverse();
} }
cards.forEach(function(card) {
card.remove(); parent.innerHTML = '';
});
var frag = document.createDocumentFragment();
sortedCards.forEach(function(card) { sortedCards.forEach(function(card) {
card.originalParentElement.appendChild(card); frag.appendChild(card);
}); });
parent.appendChild(frag);
}; };
search.addEventListener("input", applyFilter); search.addEventListener("input", function() {
applyFilter();
});
applySort(); applySort();
applyFilter(); applyFilter();
extraNetworksApplySort[tabname_full] = applySort; extraNetworksApplySort[tabname_full] = applySort;
@ -272,6 +276,15 @@ function saveCardPreview(event, tabname, filename) {
event.preventDefault(); event.preventDefault();
} }
function extraNetworksSearchButton(tabname, extra_networks_tabname, event) {
var searchTextarea = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_search");
var button = event.target;
var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
searchTextarea.value = text;
updateInput(searchTextarea);
}
function extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname) { function extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname) {
/** /**
* Processes `onclick` events when user clicks on files in tree. * Processes `onclick` events when user clicks on files in tree.
@ -383,36 +396,17 @@ function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) {
} }
function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) { function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) {
/** /** Handles `onclick` events for Sort Mode buttons. */
* Handles `onclick` events for the Sort Mode button.
* var self = event.currentTarget;
* Modifies the data attributes of the Sort Mode button to cycle between var parent = event.currentTarget.parentElement;
* various sorting modes.
* parent.querySelectorAll('.extra-network-control--sort').forEach(function(x) {
* @param event The generated event. x.classList.remove('extra-network-control--enabled');
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. });
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
*/ self.classList.add('extra-network-control--enabled');
var curr_mode = event.currentTarget.dataset.sortmode;
var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_sort_dir");
var sort_dir = el_sort_dir.dataset.sortdir;
if (curr_mode == "path") {
event.currentTarget.dataset.sortmode = "name";
event.currentTarget.dataset.sortkey = "sortName-" + sort_dir + "-640";
event.currentTarget.setAttribute("title", "Sort by filename");
} else if (curr_mode == "name") {
event.currentTarget.dataset.sortmode = "date_created";
event.currentTarget.dataset.sortkey = "sortDate_created-" + sort_dir + "-640";
event.currentTarget.setAttribute("title", "Sort by date created");
} else if (curr_mode == "date_created") {
event.currentTarget.dataset.sortmode = "date_modified";
event.currentTarget.dataset.sortkey = "sortDate_modified-" + sort_dir + "-640";
event.currentTarget.setAttribute("title", "Sort by date modified");
} else {
event.currentTarget.dataset.sortmode = "path";
event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640";
event.currentTarget.setAttribute("title", "Sort by path");
}
applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); applyExtraNetworkSort(tabname + "_" + extra_networks_tabname);
} }
@ -447,27 +441,12 @@ function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabn
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
*/ */
const tree = gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_tree"); var button = event.currentTarget;
const parent = tree.parentElement; button.classList.toggle("extra-network-control--enabled");
let resizeHandle = parent.querySelector('.resize-handle'); var show = !button.classList.contains("extra-network-control--enabled");
tree.classList.toggle("hidden");
if (tree.classList.contains("hidden")) { var pane = gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_pane");
tree.style.display = 'none'; pane.classList.toggle("extra-network-dirs-hidden", show);
parent.style.display = 'flex';
if (resizeHandle) {
resizeHandle.style.display = 'none';
}
} else {
tree.style.display = 'block';
parent.style.display = 'grid';
if (!resizeHandle) {
setupResizeHandle(parent);
resizeHandle = parent.querySelector('.resize-handle');
}
resizeHandle.style.display = 'block';
}
event.currentTarget.classList.toggle("extra-network-control--enabled");
} }
function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) { function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) {

View File

@ -131,19 +131,15 @@ function setupImageForLightbox(e) {
e.style.cursor = 'pointer'; e.style.cursor = 'pointer';
e.style.userSelect = 'none'; e.style.userSelect = 'none';
var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1; e.addEventListener('mousedown', function(evt) {
// For Firefox, listening on click first switched to next image then shows the lightbox.
// If you know how to fix this without switching to mousedown event, please.
// For other browsers the event is click to make it possiblr to drag picture.
var event = isFirefox ? 'mousedown' : 'click';
e.addEventListener(event, function(evt) {
if (evt.button == 1) { if (evt.button == 1) {
open(evt.target.src); open(evt.target.src);
evt.preventDefault(); evt.preventDefault();
return; return;
} }
}, true);
e.addEventListener('click', function(evt) {
if (!opts.js_modal_lightbox || evt.button != 0) return; if (!opts.js_modal_lightbox || evt.button != 0) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed); modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);

View File

@ -79,6 +79,11 @@
parent.minRightColWidth = 0; parent.minRightColWidth = 0;
parent.needHideOnMoblie = false; parent.needHideOnMoblie = false;
} }
if (!leftColTemplate) {
leftColTemplate = '1fr';
}
const gridTemplateColumns = `${leftColTemplate} ${PAD}px ${parent.children[1].style.flexGrow}fr`; const gridTemplateColumns = `${leftColTemplate} ${PAD}px ${parent.children[1].style.flexGrow}fr`;
parent.style.gridTemplateColumns = gridTemplateColumns; parent.style.gridTemplateColumns = gridTemplateColumns;
parent.style.originalGridTemplateColumns = gridTemplateColumns; parent.style.originalGridTemplateColumns = gridTemplateColumns;

View File

@ -136,8 +136,7 @@ function showSubmitInterruptingPlaceholder(tabname) {
function showRestoreProgressButton(tabname, show) { function showRestoreProgressButton(tabname, show) {
var button = gradioApp().getElementById(tabname + "_restore_progress"); var button = gradioApp().getElementById(tabname + "_restore_progress");
if (!button) return; if (!button) return;
button.style.setProperty('display', show ? 'flex' : 'none', 'important');
button.style.display = show ? "flex" : "none";
} }
function submit() { function submit() {
@ -209,6 +208,7 @@ function restoreProgressTxt2img() {
var id = localGet("txt2img_task_id"); var id = localGet("txt2img_task_id");
if (id) { if (id) {
showSubmitInterruptingPlaceholder('txt2img');
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
showSubmitButtons('txt2img', true); showSubmitButtons('txt2img', true);
}, null, 0); }, null, 0);
@ -223,6 +223,7 @@ function restoreProgressImg2img() {
var id = localGet("img2img_task_id"); var id = localGet("img2img_task_id");
if (id) { if (id) {
showSubmitInterruptingPlaceholder('img2img');
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() { requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
showSubmitButtons('img2img', true); showSubmitButtons('img2img', true);
}, null, 0); }, null, 0);

View File

@ -124,3 +124,4 @@ parser.add_argument("--disable-extra-extensions", action='store_true', help="pre
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui") parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui")
parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system") parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system")
parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system') parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system')
parser.add_argument("--no-prompt-history", action='store_true', help="disable read prompt from last generation feature; settings this argument will not create '--data_path/params.txt' file")

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import configparser import configparser
import dataclasses
import os import os
import threading import threading
import re import re
@ -22,6 +23,13 @@ def active():
return [x for x in extensions if x.enabled] return [x for x in extensions if x.enabled]
@dataclasses.dataclass
class CallbackOrderInfo:
name: str
before: list
after: list
class ExtensionMetadata: class ExtensionMetadata:
filename = "metadata.ini" filename = "metadata.ini"
config: configparser.ConfigParser config: configparser.ConfigParser
@ -65,6 +73,22 @@ class ExtensionMetadata:
# both "," and " " are accepted as separator # both "," and " " are accepted as separator
return [x for x in re.split(r"[,\s]+", text.strip()) if x] return [x for x in re.split(r"[,\s]+", text.strip()) if x]
def list_callback_order_instructions(self):
for section in self.config.sections():
if not section.startswith("callbacks/"):
continue
callback_name = section[10:]
if not callback_name.startswith(self.canonical_name):
errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
continue
before = self.parse_list(self.config.get(section, 'Before', fallback=''))
after = self.parse_list(self.config.get(section, 'After', fallback=''))
yield CallbackOrderInfo(callback_name, before, after)
class Extension: class Extension:
lock = threading.Lock() lock = threading.Lock()
@ -156,6 +180,8 @@ class Extension:
def check_updates(self): def check_updates(self):
repo = Repo(self.path) repo = Repo(self.path)
for fetch in repo.remote().fetch(dry_run=True): for fetch in repo.remote().fetch(dry_run=True):
if self.branch and fetch.name != f'{repo.remote().name}/{self.branch}':
continue
if fetch.flags != fetch.HEAD_UPTODATE: if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True self.can_update = True
self.status = "new commits" self.status = "new commits"
@ -186,6 +212,7 @@ class Extension:
def list_extensions(): def list_extensions():
extensions.clear() extensions.clear()
extension_paths.clear()
if shared.cmd_opts.disable_all_extensions: if shared.cmd_opts.disable_all_extensions:
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
@ -220,6 +247,7 @@ def list_extensions():
is_builtin = dirname == extensions_builtin_dir is_builtin = dirname == extensions_builtin_dir
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata) extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
extensions.append(extension) extensions.append(extension)
extension_paths[extension.path] = extension
loaded_extensions[canonical_name] = extension loaded_extensions[canonical_name] = extension
# check for requirements # check for requirements
@ -238,4 +266,19 @@ def list_extensions():
continue continue
def find_extension(filename):
parentdir = os.path.dirname(os.path.realpath(filename))
while parentdir != filename:
extension = extension_paths.get(parentdir)
if extension is not None:
return extension
filename = parentdir
parentdir = os.path.dirname(filename)
return None
extensions: list[Extension] = [] extensions: list[Extension] = []
extension_paths: dict[str, Extension] = {}

View File

@ -265,17 +265,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else: else:
prompt += ("" if prompt == "" else "\n") + line prompt += ("" if prompt == "" else "\n") + line
if shared.opts.infotext_styles != "Ignore":
found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
if shared.opts.infotext_styles == "Apply":
res["Styles array"] = found_styles
elif shared.opts.infotext_styles == "Apply if any" and found_styles:
res["Styles array"] = found_styles
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline): for k, v in re_param.findall(lastline):
try: try:
if v[0] == '"' and v[-1] == '"': if v[0] == '"' and v[-1] == '"':
@ -290,6 +279,26 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
except Exception: except Exception:
print(f"Error parsing \"{k}: {v}\"") print(f"Error parsing \"{k}: {v}\"")
# Extract styles from prompt
if shared.opts.infotext_styles != "Ignore":
found_styles, prompt_no_styles, negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
same_hr_styles = True
if ("Hires prompt" in res or "Hires negative prompt" in res) and (infotext_ver > infotext_versions.v180_hr_styles if (infotext_ver := infotext_versions.parse_version(res.get("Version"))) else True):
hr_prompt, hr_negative_prompt = res.get("Hires prompt", prompt), res.get("Hires negative prompt", negative_prompt)
hr_found_styles, hr_prompt_no_styles, hr_negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(hr_prompt, hr_negative_prompt)
if same_hr_styles := found_styles == hr_found_styles:
res["Hires prompt"] = '' if hr_prompt_no_styles == prompt_no_styles else hr_prompt_no_styles
res['Hires negative prompt'] = '' if hr_negative_prompt_no_styles == negative_prompt_no_styles else hr_negative_prompt_no_styles
if same_hr_styles:
prompt, negative_prompt = prompt_no_styles, negative_prompt_no_styles
if (shared.opts.infotext_styles == "Apply if any" and found_styles) or shared.opts.infotext_styles == "Apply":
res['Styles array'] = found_styles
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
# Missing CLIP skip means it was set to 1 (the default) # Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res: if "Clip skip" not in res:
res["Clip skip"] = "1" res["Clip skip"] = "1"
@ -462,7 +471,7 @@ def get_override_settings(params, *, skip_fields=None):
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname): def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
def paste_func(prompt): def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config: if not prompt and not shared.cmd_opts.hide_ui_dir_config and not shared.cmd_opts.no_prompt_history:
filename = os.path.join(data_path, "params.txt") filename = os.path.join(data_path, "params.txt")
try: try:
with open(filename, "r", encoding="utf8") as file: with open(filename, "r", encoding="utf8") as file:

View File

@ -6,6 +6,7 @@ import re
v160 = version.parse("1.6.0") v160 = version.parse("1.6.0")
v170_tsnr = version.parse("v1.7.0-225") v170_tsnr = version.parse("v1.7.0-225")
v180 = version.parse("1.8.0") v180 = version.parse("1.8.0")
v180_hr_styles = version.parse("1.8.0-139")
def parse_version(text): def parse_version(text):

View File

@ -240,6 +240,9 @@ class Options:
item_categories = {} item_categories = {}
for item in self.data_labels.values(): for item in self.data_labels.values():
if item.section[0] is None:
continue
category = categories.mapping.get(item.category_id) category = categories.mapping.get(item.category_id)
category = "Uncategorized" if category is None else category.label category = "Uncategorized" if category is None else category.label
if category not in item_categories: if category not in item_categories:

View File

@ -702,7 +702,7 @@ def program_version():
return res return res
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None, all_hr_prompts=None, all_hr_negative_prompts=None):
if index is None: if index is None:
index = position_in_batch + iteration * p.batch_size index = position_in_batch + iteration * p.batch_size
@ -745,11 +745,18 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"RNG": opts.randn_source if opts.randn_source != "GPU" else None, "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
"Tiling": "True" if p.tiling else None, "Tiling": "True" if p.tiling else None,
"Hires prompt": None, # This is set later, insert here to keep order
"Hires negative prompt": None, # This is set later, insert here to keep order
**p.extra_generation_params, **p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None, "Version": program_version() if opts.add_version_to_infotext else None,
"User": p.user if opts.add_user_name_to_info else None, "User": p.user if opts.add_user_name_to_info else None,
} }
if all_hr_prompts := all_hr_prompts or getattr(p, 'all_hr_prompts', None):
generation_params['Hires prompt'] = all_hr_prompts[index] if all_hr_prompts[index] != all_prompts[index] else None
if all_hr_negative_prompts := all_hr_negative_prompts or getattr(p, 'all_hr_negative_prompts', None):
generation_params['Hires negative prompt'] = all_hr_negative_prompts[index] if all_hr_negative_prompts[index] != all_negative_prompts[index] else None
generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])
prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
@ -904,7 +911,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
# infotext could be modified by that callback # infotext could be modified by that callback
# Example: a wildcard processed by process_batch sets an extra model # Example: a wildcard processed by process_batch sets an extra model
# strength, which is saved as "Model Strength: 1.0" in the infotext # strength, which is saved as "Model Strength: 1.0" in the infotext
if n == 0: if n == 0 and not cmd_opts.no_prompt_history:
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, []) processed = Processed(p, [])
file.write(processed.infotext(p, 0)) file.write(processed.infotext(p, 0))
@ -1194,12 +1201,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name: if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
if tuple(self.hr_prompt) != tuple(self.prompt):
self.extra_generation_params["Hires prompt"] = self.hr_prompt
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and self.latent_scale_mode is None: if self.enable_hr and self.latent_scale_mode is None:
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):

View File

@ -26,6 +26,13 @@ class ScriptStripComments(scripts.Script):
p.main_prompt = strip_comments(p.main_prompt) p.main_prompt = strip_comments(p.main_prompt)
p.main_negative_prompt = strip_comments(p.main_negative_prompt) p.main_negative_prompt = strip_comments(p.main_negative_prompt)
if getattr(p, 'enable_hr', False):
p.all_hr_prompts = [strip_comments(x) for x in p.all_hr_prompts]
p.all_hr_negative_prompts = [strip_comments(x) for x in p.all_hr_negative_prompts]
p.hr_prompt = strip_comments(p.hr_prompt)
p.hr_negative_prompt = strip_comments(p.hr_negative_prompt)
def before_token_counter(params: script_callbacks.BeforeTokenCounterParams): def before_token_counter(params: script_callbacks.BeforeTokenCounterParams):
if not shared.opts.enable_prompt_comments: if not shared.opts.enable_prompt_comments:

View File

@ -1,13 +1,14 @@
from __future__ import annotations
import dataclasses import dataclasses
import inspect import inspect
import os import os
from collections import namedtuple
from typing import Optional, Any from typing import Optional, Any
from fastapi import FastAPI from fastapi import FastAPI
from gradio import Blocks from gradio import Blocks
from modules import errors, timer from modules import errors, timer, extensions, shared, util
def report_exception(c, job): def report_exception(c, job):
@ -116,7 +117,105 @@ class BeforeTokenCounterParams:
is_positive: bool = True is_positive: bool = True
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) @dataclasses.dataclass
class ScriptCallback:
script: str
callback: any
name: str = None
def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None):
if filename is None:
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
extension = extensions.find_extension(filename)
extension_name = extension.canonical_name if extension else 'base'
callback_name = f"{extension_name}/{os.path.basename(filename)}/{category}"
if name is not None:
callback_name += f'/{name}'
unique_callback_name = callback_name
for index in range(1000):
existing = any(x.name == unique_callback_name for x in callbacks)
if not existing:
break
unique_callback_name = f'{callback_name}-{index+1}'
callbacks.append(ScriptCallback(filename, fun, unique_callback_name))
def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
callbacks = unordered_callbacks.copy()
callback_lookup = {x.name: x for x in callbacks}
dependencies = {}
order_instructions = {}
for extension in extensions.extensions:
for order_instruction in extension.metadata.list_callback_order_instructions():
if order_instruction.name in callback_lookup:
if order_instruction.name not in order_instructions:
order_instructions[order_instruction.name] = []
order_instructions[order_instruction.name].append(order_instruction)
if order_instructions:
for callback in callbacks:
dependencies[callback.name] = []
for callback in callbacks:
for order_instruction in order_instructions.get(callback.name, []):
for after in order_instruction.after:
if after not in callback_lookup:
continue
dependencies[callback.name].append(after)
for before in order_instruction.before:
if before not in callback_lookup:
continue
dependencies[before].append(callback.name)
sorted_names = util.topological_sort(dependencies)
callbacks = [callback_lookup[x] for x in sorted_names]
if enable_user_sort:
for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None)
if index is not None:
callbacks.insert(0, callbacks.pop(index))
return callbacks
def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True):
if unordered_callbacks is None:
unordered_callbacks = callback_map.get('callbacks_' + category, [])
if not enable_user_sort:
return sort_callbacks(category, unordered_callbacks, enable_user_sort=False)
callbacks = ordered_callbacks_map.get(category)
if callbacks is not None and len(callbacks) == len(unordered_callbacks):
return callbacks
callbacks = sort_callbacks(category, unordered_callbacks)
ordered_callbacks_map[category] = callbacks
return callbacks
def enumerate_callbacks():
for category, callbacks in callback_map.items():
if category.startswith('callbacks_'):
category = category[10:]
yield category, callbacks
callback_map = dict( callback_map = dict(
callbacks_app_started=[], callbacks_app_started=[],
callbacks_model_loaded=[], callbacks_model_loaded=[],
@ -141,14 +240,18 @@ callback_map = dict(
callbacks_before_token_counter=[], callbacks_before_token_counter=[],
) )
ordered_callbacks_map = {}
def clear_callbacks(): def clear_callbacks():
for callback_list in callback_map.values(): for callback_list in callback_map.values():
callback_list.clear() callback_list.clear()
ordered_callbacks_map.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI): def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']: for c in ordered_callbacks('app_started'):
try: try:
c.callback(demo, app) c.callback(demo, app)
timer.startup_timer.record(os.path.basename(c.script)) timer.startup_timer.record(os.path.basename(c.script))
@ -157,7 +260,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
def app_reload_callback(): def app_reload_callback():
for c in callback_map['callbacks_on_reload']: for c in ordered_callbacks('on_reload'):
try: try:
c.callback() c.callback()
except Exception: except Exception:
@ -165,7 +268,7 @@ def app_reload_callback():
def model_loaded_callback(sd_model): def model_loaded_callback(sd_model):
for c in callback_map['callbacks_model_loaded']: for c in ordered_callbacks('model_loaded'):
try: try:
c.callback(sd_model) c.callback(sd_model)
except Exception: except Exception:
@ -175,7 +278,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback(): def ui_tabs_callback():
res = [] res = []
for c in callback_map['callbacks_ui_tabs']: for c in ordered_callbacks('ui_tabs'):
try: try:
res += c.callback() or [] res += c.callback() or []
except Exception: except Exception:
@ -185,7 +288,7 @@ def ui_tabs_callback():
def ui_train_tabs_callback(params: UiTrainTabParams): def ui_train_tabs_callback(params: UiTrainTabParams):
for c in callback_map['callbacks_ui_train_tabs']: for c in ordered_callbacks('ui_train_tabs'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -193,7 +296,7 @@ def ui_train_tabs_callback(params: UiTrainTabParams):
def ui_settings_callback(): def ui_settings_callback():
for c in callback_map['callbacks_ui_settings']: for c in ordered_callbacks('ui_settings'):
try: try:
c.callback() c.callback()
except Exception: except Exception:
@ -201,7 +304,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams): def before_image_saved_callback(params: ImageSaveParams):
for c in callback_map['callbacks_before_image_saved']: for c in ordered_callbacks('before_image_saved'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -209,7 +312,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams): def image_saved_callback(params: ImageSaveParams):
for c in callback_map['callbacks_image_saved']: for c in ordered_callbacks('image_saved'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -217,7 +320,7 @@ def image_saved_callback(params: ImageSaveParams):
def extra_noise_callback(params: ExtraNoiseParams): def extra_noise_callback(params: ExtraNoiseParams):
for c in callback_map['callbacks_extra_noise']: for c in ordered_callbacks('extra_noise'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -225,7 +328,7 @@ def extra_noise_callback(params: ExtraNoiseParams):
def cfg_denoiser_callback(params: CFGDenoiserParams): def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callback_map['callbacks_cfg_denoiser']: for c in ordered_callbacks('cfg_denoiser'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -233,7 +336,7 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
def cfg_denoised_callback(params: CFGDenoisedParams): def cfg_denoised_callback(params: CFGDenoisedParams):
for c in callback_map['callbacks_cfg_denoised']: for c in ordered_callbacks('cfg_denoised'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -241,7 +344,7 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
def cfg_after_cfg_callback(params: AfterCFGCallbackParams): def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
for c in callback_map['callbacks_cfg_after_cfg']: for c in ordered_callbacks('cfg_after_cfg'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -249,7 +352,7 @@ def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
def before_component_callback(component, **kwargs): def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']: for c in ordered_callbacks('before_component'):
try: try:
c.callback(component, **kwargs) c.callback(component, **kwargs)
except Exception: except Exception:
@ -257,7 +360,7 @@ def before_component_callback(component, **kwargs):
def after_component_callback(component, **kwargs): def after_component_callback(component, **kwargs):
for c in callback_map['callbacks_after_component']: for c in ordered_callbacks('after_component'):
try: try:
c.callback(component, **kwargs) c.callback(component, **kwargs)
except Exception: except Exception:
@ -265,7 +368,7 @@ def after_component_callback(component, **kwargs):
def image_grid_callback(params: ImageGridLoopParams): def image_grid_callback(params: ImageGridLoopParams):
for c in callback_map['callbacks_image_grid']: for c in ordered_callbacks('image_grid'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
@ -273,7 +376,7 @@ def image_grid_callback(params: ImageGridLoopParams):
def infotext_pasted_callback(infotext: str, params: dict[str, Any]): def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
for c in callback_map['callbacks_infotext_pasted']: for c in ordered_callbacks('infotext_pasted'):
try: try:
c.callback(infotext, params) c.callback(infotext, params)
except Exception: except Exception:
@ -281,7 +384,7 @@ def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
def script_unloaded_callback(): def script_unloaded_callback():
for c in reversed(callback_map['callbacks_script_unloaded']): for c in reversed(ordered_callbacks('script_unloaded')):
try: try:
c.callback() c.callback()
except Exception: except Exception:
@ -289,7 +392,7 @@ def script_unloaded_callback():
def before_ui_callback(): def before_ui_callback():
for c in reversed(callback_map['callbacks_before_ui']): for c in reversed(ordered_callbacks('before_ui')):
try: try:
c.callback() c.callback()
except Exception: except Exception:
@ -299,7 +402,7 @@ def before_ui_callback():
def list_optimizers_callback(): def list_optimizers_callback():
res = [] res = []
for c in callback_map['callbacks_list_optimizers']: for c in ordered_callbacks('list_optimizers'):
try: try:
c.callback(res) c.callback(res)
except Exception: except Exception:
@ -311,7 +414,7 @@ def list_optimizers_callback():
def list_unets_callback(): def list_unets_callback():
res = [] res = []
for c in callback_map['callbacks_list_unets']: for c in ordered_callbacks('list_unets'):
try: try:
c.callback(res) c.callback(res)
except Exception: except Exception:
@ -321,20 +424,13 @@ def list_unets_callback():
def before_token_counter_callback(params: BeforeTokenCounterParams): def before_token_counter_callback(params: BeforeTokenCounterParams):
for c in callback_map['callbacks_before_token_counter']: for c in ordered_callbacks('before_token_counter'):
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
report_exception(c, 'before_token_counter') report_exception(c, 'before_token_counter')
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks(): def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file' filename = stack[0].filename if stack else 'unknown file'
@ -351,24 +447,24 @@ def remove_callbacks_for_function(callback_func):
callback_list.remove(callback_to_remove) callback_list.remove(callback_to_remove)
def on_app_started(callback): def on_app_started(callback, *, name=None):
"""register a function to be called when the webui started, the gradio `Block` component and """register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments""" fastapi `FastAPI` object are passed as the arguments"""
add_callback(callback_map['callbacks_app_started'], callback) add_callback(callback_map['callbacks_app_started'], callback, name=name, category='app_started')
def on_before_reload(callback): def on_before_reload(callback, *, name=None):
"""register a function to be called just before the server reloads.""" """register a function to be called just before the server reloads."""
add_callback(callback_map['callbacks_on_reload'], callback) add_callback(callback_map['callbacks_on_reload'], callback, name=name, category='on_reload')
def on_model_loaded(callback): def on_model_loaded(callback, *, name=None):
"""register a function to be called when the stable diffusion model is created; the model is """register a function to be called when the stable diffusion model is created; the model is
passed as an argument; this function is also called when the script is reloaded. """ passed as an argument; this function is also called when the script is reloaded. """
add_callback(callback_map['callbacks_model_loaded'], callback) add_callback(callback_map['callbacks_model_loaded'], callback, name=name, category='model_loaded')
def on_ui_tabs(callback): def on_ui_tabs(callback, *, name=None):
"""register a function to be called when the UI is creating new tabs. """register a function to be called when the UI is creating new tabs.
The function must either return a None, which means no new tabs to be added, or a list, where The function must either return a None, which means no new tabs to be added, or a list, where
each element is a tuple: each element is a tuple:
@ -378,71 +474,71 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI title is tab text displayed to user in the UI
elem_id is HTML id for the tab elem_id is HTML id for the tab
""" """
add_callback(callback_map['callbacks_ui_tabs'], callback) add_callback(callback_map['callbacks_ui_tabs'], callback, name=name, category='ui_tabs')
def on_ui_train_tabs(callback): def on_ui_train_tabs(callback, *, name=None):
"""register a function to be called when the UI is creating new tabs for the train tab. """register a function to be called when the UI is creating new tabs for the train tab.
Create your new tabs with gr.Tab. Create your new tabs with gr.Tab.
""" """
add_callback(callback_map['callbacks_ui_train_tabs'], callback) add_callback(callback_map['callbacks_ui_train_tabs'], callback, name=name, category='ui_train_tabs')
def on_ui_settings(callback): def on_ui_settings(callback, *, name=None):
"""register a function to be called before UI settings are populated; add your settings """register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """ by using shared.opts.add_option(shared.OptionInfo(...)) """
add_callback(callback_map['callbacks_ui_settings'], callback) add_callback(callback_map['callbacks_ui_settings'], callback, name=name, category='ui_settings')
def on_before_image_saved(callback): def on_before_image_saved(callback, *, name=None):
"""register a function to be called before an image is saved to a file. """register a function to be called before an image is saved to a file.
The callback is called with one argument: The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object. - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
""" """
add_callback(callback_map['callbacks_before_image_saved'], callback) add_callback(callback_map['callbacks_before_image_saved'], callback, name=name, category='before_image_saved')
def on_image_saved(callback): def on_image_saved(callback, *, name=None):
"""register a function to be called after an image is saved to a file. """register a function to be called after an image is saved to a file.
The callback is called with one argument: The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
""" """
add_callback(callback_map['callbacks_image_saved'], callback) add_callback(callback_map['callbacks_image_saved'], callback, name=name, category='image_saved')
def on_extra_noise(callback): def on_extra_noise(callback, *, name=None):
"""register a function to be called before adding extra noise in img2img or hires fix; """register a function to be called before adding extra noise in img2img or hires fix;
The callback is called with one argument: The callback is called with one argument:
- params: ExtraNoiseParams - contains noise determined by seed and latent representation of image - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
""" """
add_callback(callback_map['callbacks_extra_noise'], callback) add_callback(callback_map['callbacks_extra_noise'], callback, name=name, category='extra_noise')
def on_cfg_denoiser(callback): def on_cfg_denoiser(callback, *, name=None):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument: The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details. - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
""" """
add_callback(callback_map['callbacks_cfg_denoiser'], callback) add_callback(callback_map['callbacks_cfg_denoiser'], callback, name=name, category='cfg_denoiser')
def on_cfg_denoised(callback): def on_cfg_denoised(callback, *, name=None):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument: The callback is called with one argument:
- params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
""" """
add_callback(callback_map['callbacks_cfg_denoised'], callback) add_callback(callback_map['callbacks_cfg_denoised'], callback, name=name, category='cfg_denoised')
def on_cfg_after_cfg(callback): def on_cfg_after_cfg(callback, *, name=None):
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed. """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
The callback is called with one argument: The callback is called with one argument:
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation. - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
""" """
add_callback(callback_map['callbacks_cfg_after_cfg'], callback) add_callback(callback_map['callbacks_cfg_after_cfg'], callback, name=name, category='cfg_after_cfg')
def on_before_component(callback): def on_before_component(callback, *, name=None):
"""register a function to be called before a component is created. """register a function to be called before a component is created.
The callback is called with arguments: The callback is called with arguments:
- component - gradio component that is about to be created. - component - gradio component that is about to be created.
@ -451,61 +547,61 @@ def on_before_component(callback):
Use elem_id/label fields of kwargs to figure out which component it is. Use elem_id/label fields of kwargs to figure out which component it is.
This can be useful to inject your own components somewhere in the middle of vanilla UI. This can be useful to inject your own components somewhere in the middle of vanilla UI.
""" """
add_callback(callback_map['callbacks_before_component'], callback) add_callback(callback_map['callbacks_before_component'], callback, name=name, category='before_component')
def on_after_component(callback): def on_after_component(callback, *, name=None):
"""register a function to be called after a component is created. See on_before_component for more.""" """register a function to be called after a component is created. See on_before_component for more."""
add_callback(callback_map['callbacks_after_component'], callback) add_callback(callback_map['callbacks_after_component'], callback, name=name, category='after_component')
def on_image_grid(callback): def on_image_grid(callback, *, name=None):
"""register a function to be called before making an image grid. """register a function to be called before making an image grid.
The callback is called with one argument: The callback is called with one argument:
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
""" """
add_callback(callback_map['callbacks_image_grid'], callback) add_callback(callback_map['callbacks_image_grid'], callback, name=name, category='image_grid')
def on_infotext_pasted(callback): def on_infotext_pasted(callback, *, name=None):
"""register a function to be called before applying an infotext. """register a function to be called before applying an infotext.
The callback is called with two arguments: The callback is called with two arguments:
- infotext: str - raw infotext. - infotext: str - raw infotext.
- result: dict[str, any] - parsed infotext parameters. - result: dict[str, any] - parsed infotext parameters.
""" """
add_callback(callback_map['callbacks_infotext_pasted'], callback) add_callback(callback_map['callbacks_infotext_pasted'], callback, name=name, category='infotext_pasted')
def on_script_unloaded(callback): def on_script_unloaded(callback, *, name=None):
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
the script did should be reverted here""" the script did should be reverted here"""
add_callback(callback_map['callbacks_script_unloaded'], callback) add_callback(callback_map['callbacks_script_unloaded'], callback, name=name, category='script_unloaded')
def on_before_ui(callback): def on_before_ui(callback, *, name=None):
"""register a function to be called before the UI is created.""" """register a function to be called before the UI is created."""
add_callback(callback_map['callbacks_before_ui'], callback) add_callback(callback_map['callbacks_before_ui'], callback, name=name, category='before_ui')
def on_list_optimizers(callback): def on_list_optimizers(callback, *, name=None):
"""register a function to be called when UI is making a list of cross attention optimization options. """register a function to be called when UI is making a list of cross attention optimization options.
The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
to it.""" to it."""
add_callback(callback_map['callbacks_list_optimizers'], callback) add_callback(callback_map['callbacks_list_optimizers'], callback, name=name, category='list_optimizers')
def on_list_unets(callback): def on_list_unets(callback, *, name=None):
"""register a function to be called when UI is making a list of alternative options for unet. """register a function to be called when UI is making a list of alternative options for unet.
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.""" The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
add_callback(callback_map['callbacks_list_unets'], callback) add_callback(callback_map['callbacks_list_unets'], callback, name=name, category='list_unets')
def on_before_token_counter(callback): def on_before_token_counter(callback, *, name=None):
"""register a function to be called when UI is counting tokens for a prompt. """register a function to be called when UI is counting tokens for a prompt.
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary.""" The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
add_callback(callback_map['callbacks_before_token_counter'], callback) add_callback(callback_map['callbacks_before_token_counter'], callback, name=name, category='before_token_counter')

View File

@ -7,7 +7,9 @@ from dataclasses import dataclass
import gradio as gr import gradio as gr
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer, util
topological_sort = util.topological_sort
AlwaysVisible = object() AlwaysVisible = object()
@ -138,7 +140,6 @@ class Script:
""" """
pass pass
def before_process(self, p, *args): def before_process(self, p, *args):
""" """
This function is called very early during processing begins for AlwaysVisible scripts. This function is called very early during processing begins for AlwaysVisible scripts.
@ -369,29 +370,6 @@ scripts_data = []
postprocessing_scripts_data = [] postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def topological_sort(dependencies):
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
Ignores errors relating to missing dependeencies or circular dependencies
"""
visited = {}
result = []
def inner(name):
visited[name] = True
for dep in dependencies.get(name, []):
if dep in dependencies and dep not in visited:
inner(dep)
result.append(name)
for depname in dependencies:
if depname not in visited:
inner(depname)
return result
@dataclass @dataclass
class ScriptWithDependencies: class ScriptWithDependencies:
@ -562,6 +540,25 @@ class ScriptRunner:
self.paste_field_names = [] self.paste_field_names = []
self.inputs = [None] self.inputs = [None]
self.callback_map = {}
self.callback_names = [
'before_process',
'process',
'before_process_batch',
'after_extra_networks_activate',
'process_batch',
'postprocess',
'postprocess_batch',
'postprocess_batch_list',
'post_sample',
'on_mask_blend',
'postprocess_image',
'postprocess_maskoverlay',
'postprocess_image_after_composite',
'before_component',
'after_component',
]
self.on_before_component_elem_id = {} self.on_before_component_elem_id = {}
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks""" """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
@ -600,6 +597,8 @@ class ScriptRunner:
self.scripts.append(script) self.scripts.append(script)
self.selectable_scripts.append(script) self.selectable_scripts.append(script)
self.callback_map.clear()
self.apply_on_before_component_callbacks() self.apply_on_before_component_callbacks()
def apply_on_before_component_callbacks(self): def apply_on_before_component_callbacks(self):
@ -769,8 +768,42 @@ class ScriptRunner:
return processed return processed
def list_scripts_for_method(self, method_name):
if method_name in ('before_component', 'after_component'):
return self.scripts
else:
return self.alwayson_scripts
def create_ordered_callbacks_list(self, method_name, *, enable_user_sort=True):
script_list = self.list_scripts_for_method(method_name)
category = f'script_{method_name}'
callbacks = []
for script in script_list:
if getattr(script.__class__, method_name, None) == getattr(Script, method_name, None):
continue
script_callbacks.add_callback(callbacks, script, category=category, name=script.__class__.__name__, filename=script.filename)
return script_callbacks.sort_callbacks(category, callbacks, enable_user_sort=enable_user_sort)
def ordered_callbacks(self, method_name, *, enable_user_sort=True):
script_list = self.list_scripts_for_method(method_name)
category = f'script_{method_name}'
scrpts_len, callbacks = self.callback_map.get(category, (-1, None))
if callbacks is None or scrpts_len != len(script_list):
callbacks = self.create_ordered_callbacks_list(method_name, enable_user_sort=enable_user_sort)
self.callback_map[category] = len(script_list), callbacks
return callbacks
def ordered_scripts(self, method_name):
return [x.callback for x in self.ordered_callbacks(method_name)]
def before_process(self, p): def before_process(self, p):
for script in self.alwayson_scripts: for script in self.ordered_scripts('before_process'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.before_process(p, *script_args) script.before_process(p, *script_args)
@ -778,7 +811,7 @@ class ScriptRunner:
errors.report(f"Error running before_process: {script.filename}", exc_info=True) errors.report(f"Error running before_process: {script.filename}", exc_info=True)
def process(self, p): def process(self, p):
for script in self.alwayson_scripts: for script in self.ordered_scripts('process'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args) script.process(p, *script_args)
@ -786,7 +819,7 @@ class ScriptRunner:
errors.report(f"Error running process: {script.filename}", exc_info=True) errors.report(f"Error running process: {script.filename}", exc_info=True)
def before_process_batch(self, p, **kwargs): def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('before_process_batch'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs) script.before_process_batch(p, *script_args, **kwargs)
@ -794,7 +827,7 @@ class ScriptRunner:
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
def after_extra_networks_activate(self, p, **kwargs): def after_extra_networks_activate(self, p, **kwargs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('after_extra_networks_activate'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.after_extra_networks_activate(p, *script_args, **kwargs) script.after_extra_networks_activate(p, *script_args, **kwargs)
@ -802,7 +835,7 @@ class ScriptRunner:
errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True) errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs): def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('process_batch'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs) script.process_batch(p, *script_args, **kwargs)
@ -810,7 +843,7 @@ class ScriptRunner:
errors.report(f"Error running process_batch: {script.filename}", exc_info=True) errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
def postprocess(self, p, processed): def postprocess(self, p, processed):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args) script.postprocess(p, processed, *script_args)
@ -818,7 +851,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess: {script.filename}", exc_info=True) errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
def postprocess_batch(self, p, images, **kwargs): def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess_batch'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs) script.postprocess_batch(p, *script_args, images=images, **kwargs)
@ -826,7 +859,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs): def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess_batch_list'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch_list(p, pp, *script_args, **kwargs) script.postprocess_batch_list(p, pp, *script_args, **kwargs)
@ -834,7 +867,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
def post_sample(self, p, ps: PostSampleArgs): def post_sample(self, p, ps: PostSampleArgs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('post_sample'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.post_sample(p, ps, *script_args) script.post_sample(p, ps, *script_args)
@ -842,7 +875,7 @@ class ScriptRunner:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True) errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def on_mask_blend(self, p, mba: MaskBlendArgs): def on_mask_blend(self, p, mba: MaskBlendArgs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('on_mask_blend'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.on_mask_blend(p, mba, *script_args) script.on_mask_blend(p, mba, *script_args)
@ -850,7 +883,7 @@ class ScriptRunner:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True) errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs): def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess_image'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args) script.postprocess_image(p, pp, *script_args)
@ -858,7 +891,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess_maskoverlay'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_maskoverlay(p, ppmo, *script_args) script.postprocess_maskoverlay(p, ppmo, *script_args)
@ -866,7 +899,7 @@ class ScriptRunner:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs): def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts: for script in self.ordered_scripts('postprocess_image_after_composite'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image_after_composite(p, pp, *script_args) script.postprocess_image_after_composite(p, pp, *script_args)
@ -880,7 +913,7 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True) errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
for script in self.scripts: for script in self.ordered_scripts('before_component'):
try: try:
script.before_component(component, **kwargs) script.before_component(component, **kwargs)
except Exception: except Exception:
@ -893,7 +926,7 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True) errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
for script in self.scripts: for script in self.ordered_scripts('after_component'):
try: try:
script.after_component(component, **kwargs) script.after_component(component, **kwargs)
except Exception: except Exception:
@ -921,7 +954,7 @@ class ScriptRunner:
self.scripts[si].args_to = args_to self.scripts[si].args_to = args_to
def before_hr(self, p): def before_hr(self, p):
for script in self.alwayson_scripts: for script in self.ordered_scripts('before_hr'):
try: try:
script_args = p.script_args[script.args_from:script.args_to] script_args = p.script_args[script.args_from:script.args_to]
script.before_hr(p, *script_args) script.before_hr(p, *script_args)
@ -929,7 +962,7 @@ class ScriptRunner:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True) errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
def setup_scrips(self, p, *, is_ui=True): def setup_scrips(self, p, *, is_ui=True):
for script in self.alwayson_scripts: for script in self.ordered_scripts('setup'):
if not is_ui and script.setup_for_ui_only: if not is_ui and script.setup_for_ui_only:
continue continue

View File

@ -13,8 +13,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
for embedder in self.conditioner.embedders: for embedder in self.conditioner.embedders:
embedder.ucg_rate = 0.0 embedder.ucg_rate = 0.0
width = getattr(batch, 'width', 1024) width = getattr(batch, 'width', 1024) or 1024
height = getattr(batch, 'height', 1024) height = getattr(batch, 'height', 1024) or 1024
is_negative_prompt = getattr(batch, 'is_negative_prompt', False) is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score

View File

@ -6,6 +6,10 @@ import gradio as gr
from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from modules import util from modules import util
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from modules import shared_state, styles, interrogate, shared_total_tqdm, memmon
cmd_opts = shared_cmd_options.cmd_opts cmd_opts = shared_cmd_options.cmd_opts
parser = shared_cmd_options.parser parser = shared_cmd_options.parser
@ -16,11 +20,11 @@ styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.st
config_filename = cmd_opts.ui_settings_file config_filename = cmd_opts.ui_settings_file
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
demo = None demo: gr.Blocks = None
device = None device: str = None
weight_load_location = None weight_load_location: str = None
xformers_available = False xformers_available = False
@ -28,21 +32,21 @@ hypernetworks = {}
loaded_hypernetworks = [] loaded_hypernetworks = []
state = None state: 'shared_state.State' = None
prompt_styles = None prompt_styles: 'styles.StyleDatabase' = None
interrogator = None interrogator: 'interrogate.InterrogateModels' = None
face_restorers = [] face_restorers = []
options_templates = None options_templates: dict = None
opts = None opts: options.Options = None
restricted_opts = None restricted_opts: set[str] = None
sd_model: sd_models_types.WebuiSdModel = None sd_model: sd_models_types.WebuiSdModel = None
settings_components = None settings_components: dict = None
"""assigned from ui.py, a mapping on setting names to gradio components repsponsible for those settings""" """assigned from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
tab_names = [] tab_names = []
@ -65,9 +69,9 @@ progress_print_out = sys.stdout
gradio_theme = gr.themes.Base() gradio_theme = gr.themes.Base()
total_tqdm = None total_tqdm: 'shared_total_tqdm.TotalTQDM' = None
mem_mon = None mem_mon: 'memmon.MemUsageMonitor' = None
options_section = options.options_section options_section = options.options_section
OptionInfo = options.OptionInfo OptionInfo = options.OptionInfo

View File

@ -1,5 +1,8 @@
import html
import sys import sys
from modules import script_callbacks, scripts, ui_components
from modules.options import OptionHTML, OptionInfo
from modules.shared_cmd_options import cmd_opts from modules.shared_cmd_options import cmd_opts
@ -118,6 +121,45 @@ def ui_reorder_categories():
yield "scripts" yield "scripts"
def callbacks_order_settings():
options = {
"sd_vae_explanation": OptionHTML("""
For categories below, callbacks added to dropdowns happen before others, in order listed.
"""),
}
callback_options = {}
for category, _ in script_callbacks.enumerate_callbacks():
callback_options[category] = script_callbacks.ordered_callbacks(category, enable_user_sort=False)
for method_name in scripts.scripts_txt2img.callback_names:
callback_options["script_" + method_name] = scripts.scripts_txt2img.create_ordered_callbacks_list(method_name, enable_user_sort=False)
for method_name in scripts.scripts_img2img.callback_names:
callbacks = callback_options.get("script_" + method_name, [])
for addition in scripts.scripts_img2img.create_ordered_callbacks_list(method_name, enable_user_sort=False):
if any(x.name == addition.name for x in callbacks):
continue
callbacks.append(addition)
callback_options["script_" + method_name] = callbacks
for category, callbacks in callback_options.items():
if not callbacks:
continue
option_info = OptionInfo([], f"{category} callback priority", ui_components.DropdownMulti, {"choices": [x.name for x in callbacks]})
option_info.needs_restart()
option_info.html("<div class='info'>Default order: <ol>" + "".join(f"<li>{html.escape(x.name)}</li>\n" for x in callbacks) + "</ol></div>")
options['prioritized_callbacks_' + category] = option_info
return options
class Shared(sys.modules[__name__].__class__): class Shared(sys.modules[__name__].__class__):
""" """
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than

View File

@ -101,6 +101,7 @@ options_templates.update(options_section(('upscaling', "Upscaling", "postprocess
"DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
"DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
"set_scale_by_when_changing_upscaler": OptionInfo(False, "Automatically set the Scale by factor based on the name of the selected Upscaler."),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), { options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), {
@ -258,7 +259,8 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s
"extra_networks_card_description_is_html": OptionInfo(False, "Treat card description as HTML"), "extra_networks_card_description_is_html": OptionInfo(False, "Treat card description as HTML"),
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(), "extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(), "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
"extra_networks_tree_view_default_enabled": OptionInfo(False, "Enables the Extra Networks directory tree view by default").needs_reload_ui(), "extra_networks_tree_view_style": OptionInfo("Dirs", "Extra Networks directory view style", gr.Radio, {"choices": ["Tree", "Dirs"]}).needs_reload_ui(),
"extra_networks_tree_view_default_enabled": OptionInfo(True, "Show the Extra Networks directory view by default").needs_reload_ui(),
"extra_networks_tree_view_default_width": OptionInfo(180, "Default width for the Extra Networks directory tree view", gr.Number).needs_reload_ui(), "extra_networks_tree_view_default_width": OptionInfo(180, "Default width for the Extra Networks directory tree view", gr.Number).needs_reload_ui(),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),

View File

@ -1,3 +1,4 @@
from __future__ import annotations
from pathlib import Path from pathlib import Path
from modules import errors from modules import errors
import csv import csv

View File

@ -164,6 +164,8 @@ class ExtraNetworksPage:
self.lister = util.MassFileLister() self.lister = util.MassFileLister()
# HTML Templates # HTML Templates
self.pane_tpl = shared.html("extra-networks-pane.html") self.pane_tpl = shared.html("extra-networks-pane.html")
self.pane_content_tree_tpl = shared.html("extra-networks-pane-tree.html")
self.pane_content_dirs_tpl = shared.html("extra-networks-pane-dirs.html")
self.card_tpl = shared.html("extra-networks-card.html") self.card_tpl = shared.html("extra-networks-card.html")
self.btn_tree_tpl = shared.html("extra-networks-tree-button.html") self.btn_tree_tpl = shared.html("extra-networks-tree-button.html")
self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html") self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html")
@ -476,6 +478,47 @@ class ExtraNetworksPage:
return f"<ul class='tree-list tree-list--tree'>{res}</ul>" return f"<ul class='tree-list tree-list--tree'>{res}</ul>"
def create_dirs_view_html(self, tabname: str) -> str:
"""Generates HTML for displaying folders."""
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
for dirname in sorted(dirs, key=shared.natural_sort_key):
x = os.path.join(root, dirname)
if not os.path.isdir(x):
continue
subdir = os.path.abspath(x)[len(parentdir):]
if shared.opts.extra_networks_dir_button_function:
if not subdir.startswith(os.path.sep):
subdir = os.path.sep + subdir
else:
while subdir.startswith(os.path.sep):
subdir = subdir[1:]
is_empty = len(os.listdir(x)) == 0
if not is_empty and not subdir.endswith(os.path.sep):
subdir = subdir + os.path.sep
if (os.path.sep + "." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
continue
subdirs[subdir] = 1
if subdirs:
subdirs = {"": 1, **subdirs}
subdirs_html = "".join([f"""
<button class='lg secondary gradio-button custom-button{" search-all" if subdir == "" else ""}' onclick='extraNetworksSearchButton("{tabname}", "{self.extra_networks_tabname}", event)'>
{html.escape(subdir if subdir != "" else "all")}
</button>
""" for subdir in subdirs])
return subdirs_html
def create_card_view_html(self, tabname: str, *, none_message) -> str: def create_card_view_html(self, tabname: str, *, none_message) -> str:
"""Generates HTML for the network Card View section for a tab. """Generates HTML for the network Card View section for a tab.
@ -489,15 +532,15 @@ class ExtraNetworksPage:
Returns: Returns:
HTML formatted string. HTML formatted string.
""" """
res = "" res = []
for item in self.items.values(): for item in self.items.values():
res += self.create_item_html(tabname, item, self.card_tpl) res.append(self.create_item_html(tabname, item, self.card_tpl))
if res == "": if not res:
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()]) dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
res = none_message or shared.html("extra-networks-no-cards.html").format(dirs=dirs) res = [none_message or shared.html("extra-networks-no-cards.html").format(dirs=dirs)]
return res return "".join(res)
def create_html(self, tabname, *, empty=False): def create_html(self, tabname, *, empty=False):
"""Generates an HTML string for the current pane. """Generates an HTML string for the current pane.
@ -526,35 +569,28 @@ class ExtraNetworksPage:
if "user_metadata" not in item: if "user_metadata" not in item:
self.read_user_metadata(item) self.read_user_metadata(item)
data_sortdir = shared.opts.extra_networks_card_order show_tree = shared.opts.extra_networks_tree_view_default_enabled
data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip()
data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}"
tree_view_btn_extra_class = ""
tree_view_div_extra_class = "hidden"
tree_view_div_default_display = "none"
extra_network_pane_content_default_display = "flex"
if shared.opts.extra_networks_tree_view_default_enabled:
tree_view_btn_extra_class = "extra-network-control--enabled"
tree_view_div_extra_class = ""
tree_view_div_default_display = "block"
extra_network_pane_content_default_display = "grid"
return self.pane_tpl.format( page_params = {
**{
"tabname": tabname, "tabname": tabname,
"extra_networks_tabname": self.extra_networks_tabname, "extra_networks_tabname": self.extra_networks_tabname,
"data_sortmode": data_sortmode, "data_sortdir": shared.opts.extra_networks_card_order,
"data_sortkey": data_sortkey, "sort_path_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Path' else '',
"data_sortdir": data_sortdir, "sort_name_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Name' else '',
"tree_view_btn_extra_class": tree_view_btn_extra_class, "sort_date_created_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Created' else '',
"tree_view_div_extra_class": tree_view_div_extra_class, "sort_date_modified_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Modified' else '',
"tree_html": self.create_tree_view_html(tabname), "tree_view_btn_extra_class": "extra-network-control--enabled" if show_tree else "",
"items_html": self.create_card_view_html(tabname, none_message="Loading..." if empty else None), "items_html": self.create_card_view_html(tabname, none_message="Loading..." if empty else None),
"extra_networks_tree_view_default_width": shared.opts.extra_networks_tree_view_default_width, "extra_networks_tree_view_default_width": shared.opts.extra_networks_tree_view_default_width,
"tree_view_div_default_display": tree_view_div_default_display, "tree_view_div_default_display_class": "" if show_tree else "extra-network-dirs-hidden",
"extra_network_pane_content_default_display": extra_network_pane_content_default_display,
} }
)
if shared.opts.extra_networks_tree_view_style == "Tree":
pane_content = self.pane_content_tree_tpl.format(**page_params, tree_html=self.create_tree_view_html(tabname))
else:
pane_content = self.pane_content_dirs_tpl.format(**page_params, dirs_html=self.create_dirs_view_html(tabname))
return self.pane_tpl.format(**page_params, pane_content=pane_content)
def create_item(self, name, index=None): def create_item(self, name, index=None):
raise NotImplementedError() raise NotImplementedError()

View File

@ -133,8 +133,10 @@ class UserMetadataEditor:
filename = item.get("filename", None) filename = item.get("filename", None)
basename, ext = os.path.splitext(filename) basename, ext = os.path.splitext(filename)
with open(basename + '.json', "w", encoding="utf8") as file: metadata_path = basename + '.json'
with open(metadata_path, "w", encoding="utf8") as file:
json.dump(metadata, file, indent=4, ensure_ascii=False) json.dump(metadata, file, indent=4, ensure_ascii=False)
self.page.lister.update_file_entry(metadata_path)
def save_user_metadata(self, name, desc, notes): def save_user_metadata(self, name, desc, notes):
user_metadata = self.get_user_metadata(name) user_metadata = self.get_user_metadata(name)
@ -185,7 +187,8 @@ class UserMetadataEditor:
geninfo, items = images.read_info_from_image(image) geninfo, items = images.read_info_from_image(image)
images.save_image_with_geninfo(image, geninfo, item["local_preview"]) images.save_image_with_geninfo(image, geninfo, item["local_preview"])
self.page.lister.update_file_entry(item["local_preview"])
item['preview'] = self.page.find_preview(item["local_preview"])
return self.get_card_html(name), '' return self.get_card_html(name), ''
def setup_ui(self, gallery): def setup_ui(self, gallery):
@ -200,6 +203,3 @@ class UserMetadataEditor:
inputs=[self.edit_name_input], inputs=[self.edit_name_input],
outputs=[] outputs=[]
) )

View File

@ -104,6 +104,8 @@ class UiLoadsave:
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None)) apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
if type(x) == InputAccordion: if type(x) == InputAccordion:
if hasattr(x, 'custom_script_source'):
x.accordion.custom_script_source = x.custom_script_source
if x.accordion.visible: if x.accordion.visible:
apply_field(x.accordion, 'visible') apply_field(x.accordion, 'visible')
apply_field(x, 'value') apply_field(x, 'value')

View File

@ -1,7 +1,8 @@
import gradio as gr import gradio as gr
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items
from modules.call_queue import wrap_gradio_call from modules.call_queue import wrap_gradio_call
from modules.options import options_section
from modules.shared import opts from modules.shared import opts
from modules.ui_components import FormRow from modules.ui_components import FormRow
from modules.ui_gradio_extensions import reload_javascript from modules.ui_gradio_extensions import reload_javascript
@ -108,6 +109,11 @@ class UiSettings:
shared.settings_components = self.component_dict shared.settings_components = self.component_dict
# we add this as late as possible so that scripts have already registered their callbacks
opts.data_labels.update(options_section(('callbacks', "Callbacks", "system"), {
**shared_items.callbacks_order_settings(),
}))
opts.reorder() opts.reorder()
with gr.Blocks(analytics_enabled=False) as settings_interface: with gr.Blocks(analytics_enabled=False) as settings_interface:

View File

@ -20,7 +20,7 @@ class Upscaler:
filter = None filter = None
model = None model = None
user_path = None user_path = None
scalers: [] scalers: list
tile = True tile = True
def __init__(self, create_dirs=False): def __init__(self, create_dirs=False):

View File

@ -81,6 +81,17 @@ class MassFileListerCachedDir:
self.files = {x[0].lower(): x for x in files} self.files = {x[0].lower(): x for x in files}
self.files_cased = {x[0]: x for x in files} self.files_cased = {x[0]: x for x in files}
def update_entry(self, filename):
"""Add a file to the cache"""
file_path = os.path.join(self.dirname, filename)
try:
stat = os.stat(file_path)
entry = (filename, stat.st_mtime, stat.st_ctime)
self.files[filename.lower()] = entry
self.files_cased[filename] = entry
except FileNotFoundError as e:
print(f'MassFileListerCachedDir.add_entry: "{file_path}" {e}')
class MassFileLister: class MassFileLister:
"""A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file.""" """A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file."""
@ -136,3 +147,27 @@ class MassFileLister:
def reset(self): def reset(self):
"""Clear the cache of all directories.""" """Clear the cache of all directories."""
self.cached_dirs.clear() self.cached_dirs.clear()
def topological_sort(dependencies):
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
Ignores errors relating to missing dependeencies or circular dependencies
"""
visited = {}
result = []
def inner(name):
visited[name] = True
for dep in dependencies.get(name, []):
if dep in dependencies and dep not in visited:
inner(dep)
result.append(name)
for depname in dependencies:
if depname not in visited:
inner(depname)
return result

View File

@ -1,10 +1,12 @@
import re
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from modules import scripts_postprocessing, shared from modules import scripts_postprocessing, shared
import gradio as gr import gradio as gr
from modules.ui_components import FormRow, ToolButton from modules.ui_components import FormRow, ToolButton, InputAccordion
from modules.ui import switch_values_symbol from modules.ui import switch_values_symbol
upscale_cache = {} upscale_cache = {}
@ -17,7 +19,14 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
def ui(self): def ui(self):
selected_tab = gr.Number(value=0, visible=False) selected_tab = gr.Number(value=0, visible=False)
with gr.Column(): with InputAccordion(True, label="Upscale", elem_id="extras_upscale") as upscale_enabled:
with FormRow():
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
with FormRow():
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
with FormRow(): with FormRow():
with gr.Tabs(elem_id="extras_resize_mode"): with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
@ -32,18 +41,24 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height") upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with FormRow(): def on_selected_upscale_method(upscale_method):
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) if not shared.opts.set_scale_by_when_changing_upscaler:
return gr.update()
with FormRow(): match = re.search(r'(\d)[xX]|[xX](\d)', upscale_method)
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) if not match:
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility") return gr.update()
return gr.update(value=int(match.group(1) or match.group(2)))
upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False) upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab]) tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab]) tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
extras_upscaler_1.change(on_selected_upscale_method, inputs=[extras_upscaler_1], outputs=[upscaling_resize], show_progress="hidden")
return { return {
"upscale_enabled": upscale_enabled,
"upscale_mode": selected_tab, "upscale_mode": selected_tab,
"upscale_by": upscaling_resize, "upscale_by": upscaling_resize,
"upscale_to_width": upscaling_resize_w, "upscale_to_width": upscaling_resize_w,
@ -81,7 +96,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
return image return image
def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_enabled=True, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
if upscale_mode == 1: if upscale_mode == 1:
pp.shared.target_width = upscale_to_width pp.shared.target_width = upscale_to_width
pp.shared.target_height = upscale_to_height pp.shared.target_height = upscale_to_height
@ -89,7 +104,10 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
pp.shared.target_width = int(pp.image.width * upscale_by) pp.shared.target_width = int(pp.image.width * upscale_by)
pp.shared.target_height = int(pp.image.height * upscale_by) pp.shared.target_height = int(pp.image.height * upscale_by)
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0): def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_enabled=True, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
if not upscale_enabled:
return
if upscaler_1_name == "None": if upscaler_1_name == "None":
upscaler_1_name = None upscaler_1_name = None

View File

@ -1,6 +1,6 @@
/* temporary fix to load default gradio font in frontend instead of backend */ /* temporary fix to load default gradio font in frontend instead of backend */
@import url('webui-assets/css/sourcesanspro.css'); @import url('/webui-assets/css/sourcesanspro.css');
/* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */ /* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */
@ -528,6 +528,10 @@ table.popup-table .link{
opacity: 0.75; opacity: 0.75;
} }
.settings-comment .info ol{
margin: 0.4em 0 0.8em 1em;
}
#sysinfo_download a.sysinfo_big_link{ #sysinfo_download a.sysinfo_big_link{
font-size: 24pt; font-size: 24pt;
} }
@ -1205,12 +1209,24 @@ body.resizing .resize-handle {
overflow: hidden; overflow: hidden;
} }
.extra-network-pane .extra-network-pane-content { .extra-network-pane .extra-network-pane-content-dirs {
display: flex;
flex: 1;
flex-direction: column;
overflow: hidden;
}
.extra-network-pane .extra-network-pane-content-tree {
display: flex; display: flex;
flex: 1; flex: 1;
overflow: hidden; overflow: hidden;
} }
.extra-network-dirs-hidden .extra-network-dirs{ display: none; }
.extra-network-dirs-hidden .extra-network-tree{ display: none; }
.extra-network-dirs-hidden .resize-handle { display: none; }
.extra-network-dirs-hidden .resize-handle-row { display: flex !important; }
.extra-network-pane .extra-network-tree { .extra-network-pane .extra-network-tree {
flex: 1; flex: 1;
font-size: 1rem; font-size: 1rem;
@ -1260,7 +1276,7 @@ body.resizing .resize-handle {
.extra-network-control { .extra-network-control {
position: relative; position: relative;
display: grid; display: flex;
width: 100%; width: 100%;
padding: 0 !important; padding: 0 !important;
margin-top: 0 !important; margin-top: 0 !important;
@ -1277,6 +1293,12 @@ body.resizing .resize-handle {
align-items: start; align-items: start;
} }
.extra-network-control small{
color: var(--input-placeholder-color);
line-height: 2.2rem;
margin: 0 0.5rem 0 0.75rem;
}
.extra-network-tree .tree-list--tree {} .extra-network-tree .tree-list--tree {}
/* Remove auto indentation from tree. Will be overridden later. */ /* Remove auto indentation from tree. Will be overridden later. */
@ -1424,6 +1446,12 @@ body.resizing .resize-handle {
line-height: 1rem; line-height: 1rem;
} }
.extra-network-control .extra-network-control--search .extra-network-control--search-text::placeholder {
color: var(--input-placeholder-color);
}
/* <input> clear button (x on right side) styling */ /* <input> clear button (x on right side) styling */
.extra-network-control .extra-network-control--search .extra-network-control--search-text::-webkit-search-cancel-button { .extra-network-control .extra-network-control--search .extra-network-control--search-text::-webkit-search-cancel-button {
-webkit-appearance: none; -webkit-appearance: none;
@ -1456,19 +1484,19 @@ body.resizing .resize-handle {
background-color: var(--input-placeholder-color); background-color: var(--input-placeholder-color);
} }
.extra-network-control .extra-network-control--sort[data-sortmode="path"] .extra-network-control--sort-icon { .extra-network-control .extra-network-control--sort[data-sortkey="default"] .extra-network-control--sort-icon {
mask-image: url('data:image/svg+xml,<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path fill-rule="evenodd" clip-rule="evenodd" d="M1 5C1 3.34315 2.34315 2 4 2H8.43845C9.81505 2 11.015 2.93689 11.3489 4.27239L11.7808 6H13.5H20C21.6569 6 23 7.34315 23 9V11C23 11.5523 22.5523 12 22 12C21.4477 12 21 11.5523 21 11V9C21 8.44772 20.5523 8 20 8H13.5H11.7808H4C3.44772 8 3 8.44772 3 9V10V19C3 19.5523 3.44772 20 4 20H9C9.55228 20 10 20.4477 10 21C10 21.5523 9.55228 22 9 22H4C2.34315 22 1 20.6569 1 19V10V9V5ZM3 6.17071C3.31278 6.06015 3.64936 6 4 6H9.71922L9.40859 4.75746C9.2973 4.3123 8.89732 4 8.43845 4H4C3.44772 4 3 4.44772 3 5V6.17071ZM20.1716 18.7574C20.6951 17.967 21 17.0191 21 16C21 13.2386 18.7614 11 16 11C13.2386 11 11 13.2386 11 16C11 18.7614 13.2386 21 16 21C17.0191 21 17.967 20.6951 18.7574 20.1716L21.2929 22.7071C21.6834 23.0976 22.3166 23.0976 22.7071 22.7071C23.0976 22.3166 23.0976 21.6834 22.7071 21.2929L20.1716 18.7574ZM13 16C13 14.3431 14.3431 13 16 13C17.6569 13 19 14.3431 19 16C19 17.6569 17.6569 19 16 19C14.3431 19 13 17.6569 13 16Z" fill="%23000000"></path></g></svg>'); mask-image: url('data:image/svg+xml,<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path fill-rule="evenodd" clip-rule="evenodd" d="M1 5C1 3.34315 2.34315 2 4 2H8.43845C9.81505 2 11.015 2.93689 11.3489 4.27239L11.7808 6H13.5H20C21.6569 6 23 7.34315 23 9V11C23 11.5523 22.5523 12 22 12C21.4477 12 21 11.5523 21 11V9C21 8.44772 20.5523 8 20 8H13.5H11.7808H4C3.44772 8 3 8.44772 3 9V10V19C3 19.5523 3.44772 20 4 20H9C9.55228 20 10 20.4477 10 21C10 21.5523 9.55228 22 9 22H4C2.34315 22 1 20.6569 1 19V10V9V5ZM3 6.17071C3.31278 6.06015 3.64936 6 4 6H9.71922L9.40859 4.75746C9.2973 4.3123 8.89732 4 8.43845 4H4C3.44772 4 3 4.44772 3 5V6.17071ZM20.1716 18.7574C20.6951 17.967 21 17.0191 21 16C21 13.2386 18.7614 11 16 11C13.2386 11 11 13.2386 11 16C11 18.7614 13.2386 21 16 21C17.0191 21 17.967 20.6951 18.7574 20.1716L21.2929 22.7071C21.6834 23.0976 22.3166 23.0976 22.7071 22.7071C23.0976 22.3166 23.0976 21.6834 22.7071 21.2929L20.1716 18.7574ZM13 16C13 14.3431 14.3431 13 16 13C17.6569 13 19 14.3431 19 16C19 17.6569 17.6569 19 16 19C14.3431 19 13 17.6569 13 16Z" fill="%23000000"></path></g></svg>');
} }
.extra-network-control .extra-network-control--sort[data-sortmode="name"] .extra-network-control--sort-icon { .extra-network-control .extra-network-control--sort[data-sortkey="name"] .extra-network-control--sort-icon {
mask-image: url('data:image/svg+xml,<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path fill-rule="evenodd" clip-rule="evenodd" d="M17.1841 6.69223C17.063 6.42309 16.7953 6.25 16.5002 6.25C16.2051 6.25 15.9374 6.42309 15.8162 6.69223L11.3162 16.6922C11.1463 17.07 11.3147 17.514 11.6924 17.6839C12.0701 17.8539 12.5141 17.6855 12.6841 17.3078L14.1215 14.1136H18.8789L20.3162 17.3078C20.4862 17.6855 20.9302 17.8539 21.308 17.6839C21.6857 17.514 21.8541 17.07 21.6841 16.6922L17.1841 6.69223ZM16.5002 8.82764L14.7965 12.6136H18.2039L16.5002 8.82764Z" fill="%231C274C"></path><path opacity="0.5" fill-rule="evenodd" clip-rule="evenodd" d="M2.25 7C2.25 6.58579 2.58579 6.25 3 6.25H13C13.4142 6.25 13.75 6.58579 13.75 7C13.75 7.41421 13.4142 7.75 13 7.75H3C2.58579 7.75 2.25 7.41421 2.25 7Z" fill="%231C274C"></path><path opacity="0.5" d="M2.25 12C2.25 11.5858 2.58579 11.25 3 11.25H10C10.4142 11.25 10.75 11.5858 10.75 12C10.75 12.4142 10.4142 12.75 10 12.75H3C2.58579 12.75 2.25 12.4142 2.25 12Z" fill="%231C274C"></path><path opacity="0.5" d="M2.25 17C2.25 16.5858 2.58579 16.25 3 16.25H8C8.41421 16.25 8.75 16.5858 8.75 17C8.75 17.4142 8.41421 17.75 8 17.75H3C2.58579 17.75 2.25 17.4142 2.25 17Z" fill="%231C274C"></path></g></svg>'); mask-image: url('data:image/svg+xml,<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path fill-rule="evenodd" clip-rule="evenodd" d="M17.1841 6.69223C17.063 6.42309 16.7953 6.25 16.5002 6.25C16.2051 6.25 15.9374 6.42309 15.8162 6.69223L11.3162 16.6922C11.1463 17.07 11.3147 17.514 11.6924 17.6839C12.0701 17.8539 12.5141 17.6855 12.6841 17.3078L14.1215 14.1136H18.8789L20.3162 17.3078C20.4862 17.6855 20.9302 17.8539 21.308 17.6839C21.6857 17.514 21.8541 17.07 21.6841 16.6922L17.1841 6.69223ZM16.5002 8.82764L14.7965 12.6136H18.2039L16.5002 8.82764Z" fill="%231C274C"></path><path opacity="0.5" fill-rule="evenodd" clip-rule="evenodd" d="M2.25 7C2.25 6.58579 2.58579 6.25 3 6.25H13C13.4142 6.25 13.75 6.58579 13.75 7C13.75 7.41421 13.4142 7.75 13 7.75H3C2.58579 7.75 2.25 7.41421 2.25 7Z" fill="%231C274C"></path><path opacity="0.5" d="M2.25 12C2.25 11.5858 2.58579 11.25 3 11.25H10C10.4142 11.25 10.75 11.5858 10.75 12C10.75 12.4142 10.4142 12.75 10 12.75H3C2.58579 12.75 2.25 12.4142 2.25 12Z" fill="%231C274C"></path><path opacity="0.5" d="M2.25 17C2.25 16.5858 2.58579 16.25 3 16.25H8C8.41421 16.25 8.75 16.5858 8.75 17C8.75 17.4142 8.41421 17.75 8 17.75H3C2.58579 17.75 2.25 17.4142 2.25 17Z" fill="%231C274C"></path></g></svg>');
} }
.extra-network-control .extra-network-control--sort[data-sortmode="date_created"] .extra-network-control--sort-icon { .extra-network-control .extra-network-control--sort[data-sortkey="date_created"] .extra-network-control--sort-icon {
mask-image: url('data:image/svg+xml,<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path d="M17 11C14.2386 11 12 13.2386 12 16C12 18.7614 14.2386 21 17 21C19.7614 21 22 18.7614 22 16C22 13.2386 19.7614 11 17 11ZM17 11V9M2 9V15.8C2 16.9201 2 17.4802 2.21799 17.908C2.40973 18.2843 2.71569 18.5903 3.09202 18.782C3.51984 19 4.0799 19 5.2 19H13M2 9V8.2C2 7.0799 2 6.51984 2.21799 6.09202C2.40973 5.71569 2.71569 5.40973 3.09202 5.21799C3.51984 5 4.0799 5 5.2 5H13.8C14.9201 5 15.4802 5 15.908 5.21799C16.2843 5.40973 16.5903 5.71569 16.782 6.09202C17 6.51984 17 7.0799 17 8.2V9M2 9H17M5 3V5M14 3V5M15 16H17M17 16H19M17 16V14M17 16V18" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"></path></g></svg>'); mask-image: url('data:image/svg+xml,<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path d="M17 11C14.2386 11 12 13.2386 12 16C12 18.7614 14.2386 21 17 21C19.7614 21 22 18.7614 22 16C22 13.2386 19.7614 11 17 11ZM17 11V9M2 9V15.8C2 16.9201 2 17.4802 2.21799 17.908C2.40973 18.2843 2.71569 18.5903 3.09202 18.782C3.51984 19 4.0799 19 5.2 19H13M2 9V8.2C2 7.0799 2 6.51984 2.21799 6.09202C2.40973 5.71569 2.71569 5.40973 3.09202 5.21799C3.51984 5 4.0799 5 5.2 5H13.8C14.9201 5 15.4802 5 15.908 5.21799C16.2843 5.40973 16.5903 5.71569 16.782 6.09202C17 6.51984 17 7.0799 17 8.2V9M2 9H17M5 3V5M14 3V5M15 16H17M17 16H19M17 16V14M17 16V18" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"></path></g></svg>');
} }
.extra-network-control .extra-network-control--sort[data-sortmode="date_modified"] .extra-network-control--sort-icon { .extra-network-control .extra-network-control--sort[data-sortkey="date_modified"] .extra-network-control--sort-icon {
mask-image: url('data:image/svg+xml,<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path d="M10 21H6.2C5.0799 21 4.51984 21 4.09202 20.782C3.71569 20.5903 3.40973 20.2843 3.21799 19.908C3 19.4802 3 18.9201 3 17.8V8.2C3 7.0799 3 6.51984 3.21799 6.09202C3.40973 5.71569 3.71569 5.40973 4.09202 5.21799C4.51984 5 5.0799 5 6.2 5H17.8C18.9201 5 19.4802 5 19.908 5.21799C20.2843 5.40973 20.5903 5.71569 20.782 6.09202C21 6.51984 21 7.0799 21 8.2V10M7 3V5M17 3V5M3 9H21M13.5 13.0001L7 13M10 17.0001L7 17M14 21L16.025 20.595C16.2015 20.5597 16.2898 20.542 16.3721 20.5097C16.4452 20.4811 16.5147 20.4439 16.579 20.399C16.6516 20.3484 16.7152 20.2848 16.8426 20.1574L21 16C21.5523 15.4477 21.5523 14.5523 21 14C20.4477 13.4477 19.5523 13.4477 19 14L14.8426 18.1574C14.7152 18.2848 14.6516 18.3484 14.601 18.421C14.5561 18.4853 14.5189 18.5548 14.4903 18.6279C14.458 18.7102 14.4403 18.7985 14.405 18.975L14 21Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"></path></g></svg>'); mask-image: url('data:image/svg+xml,<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path d="M10 21H6.2C5.0799 21 4.51984 21 4.09202 20.782C3.71569 20.5903 3.40973 20.2843 3.21799 19.908C3 19.4802 3 18.9201 3 17.8V8.2C3 7.0799 3 6.51984 3.21799 6.09202C3.40973 5.71569 3.71569 5.40973 4.09202 5.21799C4.51984 5 5.0799 5 6.2 5H17.8C18.9201 5 19.4802 5 19.908 5.21799C20.2843 5.40973 20.5903 5.71569 20.782 6.09202C21 6.51984 21 7.0799 21 8.2V10M7 3V5M17 3V5M3 9H21M13.5 13.0001L7 13M10 17.0001L7 17M14 21L16.025 20.595C16.2015 20.5597 16.2898 20.542 16.3721 20.5097C16.4452 20.4811 16.5147 20.4439 16.579 20.399C16.6516 20.3484 16.7152 20.2848 16.8426 20.1574L21 16C21.5523 15.4477 21.5523 14.5523 21 14C20.4477 13.4477 19.5523 13.4477 19 14L14.8426 18.1574C14.7152 18.2848 14.6516 18.3484 14.601 18.421C14.5561 18.4853 14.5189 18.5548 14.4903 18.6279C14.458 18.7102 14.4403 18.7985 14.405 18.975L14 21Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"></path></g></svg>');
} }
@ -1518,13 +1546,18 @@ body.resizing .resize-handle {
} }
.extra-network-control .extra-network-control--enabled { .extra-network-control .extra-network-control--enabled {
background-color: rgba(0, 0, 0, 0.15); background-color: rgba(0, 0, 0, 0.1);
border-radius: 0.25rem;
} }
.dark .extra-network-control .extra-network-control--enabled { .dark .extra-network-control .extra-network-control--enabled {
background-color: rgba(255, 255, 255, 0.15); background-color: rgba(255, 255, 255, 0.15);
} }
.extra-network-control .extra-network-control--enabled .extra-network-control--icon{
background-color: var(--button-secondary-text-color);
}
/* ==== REFRESH ICON ACTIONS ==== */ /* ==== REFRESH ICON ACTIONS ==== */
.extra-network-control .extra-network-control--refresh { .extra-network-control .extra-network-control--refresh {
padding: 0.25rem; padding: 0.25rem;

View File

@ -130,12 +130,18 @@ case "$gpu_info" in
if [[ -z "${TORCH_COMMAND}" ]] if [[ -z "${TORCH_COMMAND}" ]]
then then
pyv="$(${python_cmd} -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')" pyv="$(${python_cmd} -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')"
if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]] # Using an old nightly compiled against rocm 5.2 for Navi1, see https://github.com/pytorch/pytorch/issues/106728#issuecomment-1749511711
if [[ $pyv == "3.8" ]]
then then
# Navi users will still use torch 1.13 because 2.0 does not seem to work. export TORCH_COMMAND="pip install https://download.pytorch.org/whl/nightly/rocm5.2/torch-2.0.0.dev20230209%2Brocm5.2-cp38-cp38-linux_x86_64.whl https://download.pytorch.org/whl/nightly/rocm5.2/torchvision-0.15.0.dev20230209%2Brocm5.2-cp38-cp38-linux_x86_64.whl"
export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6" elif [[ $pyv == "3.9" ]]
then
export TORCH_COMMAND="pip install https://download.pytorch.org/whl/nightly/rocm5.2/torch-2.0.0.dev20230209%2Brocm5.2-cp39-cp39-linux_x86_64.whl https://download.pytorch.org/whl/nightly/rocm5.2/torchvision-0.15.0.dev20230209%2Brocm5.2-cp39-cp39-linux_x86_64.whl"
elif [[ $pyv == "3.10" ]]
then
export TORCH_COMMAND="pip install https://download.pytorch.org/whl/nightly/rocm5.2/torch-2.0.0.dev20230209%2Brocm5.2-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/rocm5.2/torchvision-0.15.0.dev20230209%2Brocm5.2-cp310-cp310-linux_x86_64.whl"
else else
printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m" printf "\e[1m\e[31mERROR: RX 5000 series GPUs python version must be between 3.8 and 3.10, aborting...\e[0m"
exit 1 exit 1
fi fi
fi fi
@ -143,7 +149,7 @@ case "$gpu_info" in
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;; ;;
*"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \ *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7" export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7"
;; ;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"