From 920a3a4dceafd56f2c45f14ea61a31baa05bfc92 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 31 Jul 2024 02:43:13 +0900 Subject: [PATCH] SD3 Lora page filter - detection not implemented --- extensions-builtin/Lora/network.py | 2 ++ .../Lora/scripts/lora_script.py | 3 ++- .../Lora/ui_edit_user_metadata.py | 2 +- .../Lora/ui_extra_networks_lora.py | 21 +++++++++++++++++-- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 98ff367fd..dc413998e 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -19,6 +19,7 @@ class SdVersion(enum.Enum): SD1 = 2 SD2 = 3 SDXL = 4 + SD3 = 5 class NetworkOnDisk: @@ -59,6 +60,7 @@ class NetworkOnDisk: self.sd_version = self.detect_version() def detect_version(self): + # TODO: SdVersion.SD3 detection if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"): return SdVersion.SDXL elif str(self.metadata.get('ss_v2', "")) == "True": diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index d3ea369ae..4075c43c9 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -38,7 +38,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), "lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'), "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), - "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), + "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL", "SD3"]}), + "TEMP_setting_sd3_lora_filter": shared.OptionInfo(["SD1", "Unknown"], "For SD3 model also show Lora of other sd version", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL", "Unknown"]}).info('Temporary setting until SD3 Lora detection is implemented'), "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index b6c4d1c6a..654b47d39 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -160,7 +160,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) def create_extra_default_items_in_left_column(self): # this would be a lot better as gr.Radio but I can't make it work - self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True) + self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'SD3', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True) def create_editor(self): self.create_default_editor_elems() diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 3e34d69dc..a6077e7b6 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -3,7 +3,7 @@ import os import network import networks -from modules import shared, ui_extra_networks +from modules import shared, ui_extra_networks, sd_models_types from modules.ui_extra_networks import quote_js from ui_edit_user_metadata import LoraUserMetadataEditor @@ -62,8 +62,14 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): if shared.opts.lora_show_all or not enable_filter or not shared.sd_model: pass + elif shared.sd_model.is_sd3: + # TODO: add proper SD3 filtering when detection is implemented + # TODO: move after Unknown block when implemented + if sd_version is network.SdVersion.SD3 or sd_version.name in shared.opts.TEMP_setting_sd3_lora_filter: + return item + return None elif sd_version == network.SdVersion.Unknown: - model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1 + model_version = self.sd_to_lora_version(shared.sd_model) if model_version.name in shared.opts.lora_hide_unknown_for_versions: return None elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL: @@ -88,3 +94,14 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): def create_user_metadata_editor(self, ui, tabname): return LoraUserMetadataEditor(ui, tabname, self) + + @staticmethod + def sd_to_lora_version(sd_model: sd_models_types.WebuiSdModel): + if sd_model.is_sd1: + return network.SdVersion.SD1 + elif sd_model.is_sd2: + return network.SdVersion.SD2 + elif sd_model.is_sdxl: + return network.SdVersion.SDXL + elif sd_model.is_sd3: + return network.SdVersion.SD3