Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-01-20 05:10:15 +08:00)
extract backup/restore io-bound operations out of forward hooks to speed up
This commit is contained in:
parent 0ab4d7992c
commit 04f9084253
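In short: the weight backup/restore copies are io-bound (device-to-CPU transfers), and performing them inside each layer's forward hook means paying that cost on every forward call. The commit extracts the copies into explicit steps that run once when the set of loaded networks changes. A rough standalone sketch of the pattern — the hook names here are illustrative, not webui code:

import torch

layer = torch.nn.Linear(4, 4)

# Before: the io-bound copy lives inside the forward pre-hook and is
# paid on every forward call of every patched layer.
def slow_pre_hook(module, args):
    module.network_weights_backup = module.weight.detach().clone().cpu()

# After: pay the copy once, up front; the hook itself stays cheap.
layer.network_weights_backup = layer.weight.detach().clone().cpu()

def fast_pre_hook(module, args):
    # nothing io-bound left here; weight patching can stay on-device
    pass

layer.register_forward_pre_hook(fast_pre_hook)
layer(torch.randn(1, 4))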
@@ -417,16 +417,8 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     restore_weights_backup(self, 'bias', bias_backup)
 
 
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
-    """
-    Applies the currently selected set of networks to the weights of torch layer self.
-    If weights already have this particular set of networks applied, does nothing.
-    If not, restores original weights from backup and alters weights according to networks.
-    """
-
+def network_backup_weights(self):
     network_layer_name = getattr(self, 'network_layer_name', None)
     if network_layer_name is None:
         return
 
     current_names = getattr(self, "network_current_names", ())
     wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
@@ -459,9 +451,31 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
 
     self.network_bias_backup = bias_backup
 
-    if current_names != wanted_names:
+
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+    """
+    Applies the currently selected set of networks to the weights of torch layer self.
+    If weights already have this particular set of networks applied, does nothing.
+    If not, restores original weights from backup and alters weights according to networks.
+    """
+
+    network_layer_name = getattr(self, 'network_layer_name', None)
+    if network_layer_name is None:
+        return
+
+    current_names = getattr(self, "network_current_names", ())
+    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
+
+    weights_backup = getattr(self, "network_weights_backup", None)
+    if weights_backup is None and wanted_names != ():
+        network_backup_weights(self)
+    elif current_names != () and current_names != wanted_names and not getattr(self, "weights_restored", False):
         network_restore_weights_from_backup(self)
 
+    if current_names != wanted_names:
+        if hasattr(self, "weights_restored"):
+            self.weights_restored = False
+
         for net in loaded_networks:
             module = net.modules.get(network_layer_name, None)
             if module is not None and hasattr(self, 'weight') and not all(isinstance(module, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)):
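The two hunks above split the old network_apply_weights into a pure backup step (network_backup_weights) and an apply step that only restores inside the hook when nothing else has done so yet; the weights_restored flag is the handshake with the eager pass added in the ScriptLora hunks below. A minimal standalone analogue of that control flow, using a plain object and made-up network names:

class Layer:
    """Stand-in for a torch layer carrying the diff's bookkeeping attributes."""

layer = Layer()
layer.network_current_names = ("old-lora",)  # set currently baked into the weights
wanted_names = ("new-lora",)                 # set we now want applied

# Eager pass (ScriptLora side): restore originals up front, mark the layer.
layer.weights_restored = True

# Hook path (network_apply_weights side): the flag suppresses a second,
# redundant io-bound restore, then is consumed so a later change restores again.
if layer.network_current_names != wanted_names:
    needs_restore = not getattr(layer, "weights_restored", False)
    assert needs_restore is False  # the eager pass already paid the cost
    layer.weights_restored = False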
@@ -143,6 +143,10 @@ class ScriptLora(scripts.Script):
 
     target_dtype = devices.dtype_inference
     for module in modules:
+        network_layer_name = getattr(module, 'network_layer_name', None)
+        if network_layer_name is None:
+            continue
+
         if isinstance(module, torch.nn.MultiheadAttention):
             org_dtype = torch.float32
         else:
@@ -155,6 +159,18 @@ class ScriptLora(scripts.Script):
         # set org_dtype
         module.org_dtype = org_dtype
 
+        # backup/restore weights
+        current_names = getattr(module, "network_current_names", ())
+        wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in networks.loaded_networks)
+
+        weights_backup = getattr(module, "network_weights_backup", None)
+
+        if current_names == () and current_names != wanted_names and weights_backup is None:
+            networks.network_backup_weights(module)
+        elif current_names != () and current_names != wanted_names:
+            networks.network_restore_weights_from_backup(module)
+            module.weights_restored = True
+
 
 script_callbacks.on_infotext_pasted(infotext_pasted)
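The ScriptLora hunk runs the backup/restore decision once per module before inference, rather than once per forward call. A minimal sketch of that eager-pass shape on a toy model (attribute names assumed from the diff; this is not the webui API):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

# Eager pass: snapshot originals for every layer up front, outside any hook.
for module in model.modules():
    if not hasattr(module, "weight"):
        continue  # skip containers such as the Sequential itself
    if getattr(module, "network_weights_backup", None) is None:
        # one-time, io-bound copy (e.g. off-device); forward hooks never repeat it
        module.network_weights_backup = module.weight.detach().clone().cpu()

model(torch.randn(1, 8))  # inference proceeds with cheap hooks only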