extract IO-bound backup/restore operations out of forward hooks to speed up the forward pass

Won-Kyu Park 2024-10-03 18:04:51 +09:00
parent 0ab4d7992c
commit 04f9084253
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
2 changed files with 40 additions and 10 deletions
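
What this buys: the weight backup and restore are host-memory copies (IO-bound), while the actual network patching is device compute. Doing the copies inside the forward hooks serializes them with every forward pass; this commit moves them into a one-time pass over the modules so the hooks only patch weights. A minimal, self-contained sketch of the pattern, with invented names (backup_weights and apply_delta are illustrations, not this repo's API):

import torch

def backup_weights(module):
    # IO-bound: snapshot the original weights to host memory, once per module.
    if getattr(module, "weights_backup", None) is None:
        module.weights_backup = module.weight.detach().to("cpu", copy=True)

def apply_delta(module, delta):
    # Compute-bound: the only work left for the forward hook.
    with torch.no_grad():
        module.weight += delta.to(module.weight)

layer = torch.nn.Linear(4, 4)
backup_weights(layer)                  # done ahead of time, outside the hook
apply_delta(layer, torch.zeros(4, 4))  # hook-time work stays on-device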

@@ -417,16 +417,8 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
         restore_weights_backup(self, 'bias', bias_backup)
 
 
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
-    """
-    Applies the currently selected set of networks to the weights of torch layer self.
-    If weights already have this particular set of networks applied, does nothing.
-    If not, restores original weights from backup and alters weights according to networks.
-    """
-
+def network_backup_weights(self):
     network_layer_name = getattr(self, 'network_layer_name', None)
     if network_layer_name is None:
         return
 
     current_names = getattr(self, "network_current_names", ())
     wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
@@ -459,9 +451,31 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
         self.network_bias_backup = bias_backup
 
-    if current_names != wanted_names:
+
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
+    """
+    Applies the currently selected set of networks to the weights of torch layer self.
+    If weights already have this particular set of networks applied, does nothing.
+    If not, restores original weights from backup and alters weights according to networks.
+    """
+
+    network_layer_name = getattr(self, 'network_layer_name', None)
+    if network_layer_name is None:
+        return
+
+    current_names = getattr(self, "network_current_names", ())
+    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
+
+    weights_backup = getattr(self, "network_weights_backup", None)
+    if weights_backup is None and wanted_names != ():
+        network_backup_weights(self)
+    elif current_names != () and current_names != wanted_names and not getattr(self, "weights_restored", False):
+        network_restore_weights_from_backup(self)
+
+    if current_names != wanted_names:
+        if hasattr(self, "weights_restored"):
+            self.weights_restored = False
+
         for net in loaded_networks:
             module = net.modules.get(network_layer_name, None)
 
             if module is not None and hasattr(self, 'weight') and not all(isinstance(module, linear) for linear in (modules.models.sd3.mmdit.QkvLinear, modules.models.flux.modules.layers.QkvLinear)):
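
The hook keeps a lazy fallback: network_apply_weights still backs up or restores on its own if the up-front pass missed a module, and the weights_restored flag is the handshake that prevents a second, redundant restore. A toy restatement of that handshake (Layer and restore_from_backup are stand-ins, not the repo's code):

class Layer:
    weights_restored = False

def restore_from_backup(layer):
    print("restore (IO-bound copy back from the CPU backup)")

layer = Layer()

# up-front pass (script side): restore once, then mark it done
restore_from_backup(layer)
layer.weights_restored = True

# forward hook (networks side): skip the redundant restore, then clear the flag
if not getattr(layer, "weights_restored", False):
    restore_from_backup(layer)  # skipped here: already restored up front
layer.weights_restored = False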

@@ -143,6 +143,10 @@ class ScriptLora(scripts.Script):
         target_dtype = devices.dtype_inference
 
         for module in modules:
             network_layer_name = getattr(module, 'network_layer_name', None)
+            if network_layer_name is None:
+                continue
+
             if isinstance(module, torch.nn.MultiheadAttention):
                 org_dtype = torch.float32
             else:
@@ -155,6 +159,18 @@ class ScriptLora(scripts.Script):
             # set org_dtype
             module.org_dtype = org_dtype
 
+            # backup/restore weights
+            current_names = getattr(module, "network_current_names", ())
+            wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in networks.loaded_networks)
+            weights_backup = getattr(module, "network_weights_backup", None)
+
+            if current_names == () and current_names != wanted_names and weights_backup is None:
+                networks.network_backup_weights(module)
+            elif current_names != () and current_names != wanted_names:
+                networks.network_restore_weights_from_backup(module)
+                module.weights_restored = True
 
 
 script_callbacks.on_infotext_pasted(infotext_pasted)
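
Read the script-side branches as a small decision table: back up only on the first application (no networks applied yet, no snapshot on the module), restore only when the applied set actually changed, otherwise do nothing. A self-contained restatement with invented names (real entries are (name, te_multiplier, unet_multiplier, dyn_dim) tuples, simplified to strings here):

def decide(current_names, wanted_names, has_backup):
    # Mirrors the two branches added above, as a pure function.
    if current_names == () and current_names != wanted_names and not has_backup:
        return "backup"   # first application: snapshot the original weights
    if current_names != () and current_names != wanted_names:
        return "restore"  # network set changed: roll back before re-patching
    return "nothing"      # same set already applied, or nothing wanted

assert decide((), ("lora_a",), False) == "backup"
assert decide(("lora_a",), ("lora_b",), True) == "restore"
assert decide(("lora_a",), ("lora_a",), True) == "nothing"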