Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
remove duplicated code

commit a8cbe50c9f
parent 891ccb767c
extensions-builtin/Lora/networks.py

@@ -15,7 +15,7 @@ import torch
 from typing import Union
 
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
-from modules.textual_inversion.textual_inversion import Embedding
+import modules.textual_inversion.textual_inversion as textual_inversion
 
 from lora_logger import logger
 
@@ -210,34 +210,7 @@ def load_network(name, network_on_disk):
 
     embeddings = {}
     for emb_name, data in bundle_embeddings.items():
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, emb_name)
-        embedding.vectors = vectors
-        embedding.shape = shape
+        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
         embedding.loaded = None
         embeddings[emb_name] = embedding
 
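Note: the hunk above collapses the Lora extension's bundled-embedding parsing into one call to the helper added in modules/textual_inversion/textual_inversion.py below. A minimal sketch of the new code path in isolation (names come from the diff; `bundle_embeddings` and `network_on_disk` are provided by the surrounding `load_network`, which this diff does not show):

    # Sketch, assuming the webui runtime. Because `filepath` is omitted, the
    # helper uses `filename` only in its error message and computes no hash.
    embeddings = {}
    for emb_name, data in bundle_embeddings.items():
        embedding = textual_inversion.create_embedding_from_data(
            data, emb_name,
            filename=network_on_disk.filename + "/" + emb_name,
        )
        embedding.loaded = None  # not yet registered with the embedding database
        embeddings[emb_name] = embedding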
modules/textual_inversion/textual_inversion.py

@@ -181,40 +181,7 @@ class EmbeddingDatabase:
         else:
             return
 
-
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, name)
-        embedding.step = data.get('step', None)
-        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
-        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-        embedding.vectors = vectors
-        embedding.shape = shape
-        embedding.filename = path
-        embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
+        embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
 
         if self.expected_shape == -1 or self.expected_shape == embedding.shape:
             self.register_embedding(embedding, shared.sd_model)
@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     return fn
 
 
+def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
+    if 'string_to_param' in data:  # textual inversion embeddings
+        param_dict = data['string_to_param']
+        param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
+        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+        emb = next(iter(param_dict.items()))[1]
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+        vectors = data['clip_g'].shape[0]
+    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
+        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+        emb = next(iter(data.values()))
+        if len(emb.shape) == 1:
+            emb = emb.unsqueeze(0)
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    else:
+        raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+    embedding = Embedding(vec, name)
+    embedding.step = data.get('step', None)
+    embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+    embedding.vectors = vectors
+    embedding.shape = shape
+
+    if filepath:
+        embedding.filename = filepath
+        embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
+
+    return embedding
+
+
 def write_loss(log_directory, filename, step, epoch_len, values):
     if shared.opts.training_write_csv_every == 0:
         return
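For reference, a minimal sketch of calling the extracted helper directly. Everything here except the `create_embedding_from_data` signature is an assumption for illustration: the token name and tensor are made up, and `devices.device` means this only resolves inside a running webui process. The payload is shaped like a "diffuser concept" file, the third branch of the helper's if/elif chain:

    import torch

    from modules.textual_inversion.textual_inversion import create_embedding_from_data

    # Hypothetical single-term payload, as a diffuser-concept checkpoint would contain.
    data = {"<my-token>": torch.randn(768)}

    embedding = create_embedding_from_data(data, "<my-token>", filename="bundled/<my-token>")

    # 1D tensors are unsqueezed to (1, 768), so this prints: 1 768
    print(embedding.vectors, embedding.shape)

Since `filepath` is omitted here (as at the Lora call site), the helper skips setting `embedding.filename` and the sha256 hash; `EmbeddingDatabase.load_from_file` passes `filepath=path`, preserving the old behavior of hashing embeddings loaded from disk.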