diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 88d68c76c..375178ede 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None): self.filename = filename self.filename_text = filename_text self.latent_dist = latent_dist @@ -25,6 +25,7 @@ class DatasetEntry: self.cond = cond self.cond_text = cond_text self.pixel_values = pixel_values + self.img_shape = img_shape class PersonalizedBase(Dataset): @@ -33,8 +34,6 @@ class PersonalizedBase(Dataset): self.placeholder_token = placeholder_token - self.width = width - self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.dataset = [] @@ -59,7 +58,11 @@ class PersonalizedBase(Dataset): if shared.state.interrupted: raise Exception("interrupted") try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) + image = Image.open(path).convert('RGB') + if width < 2000: + image = image.resize((width, height), PIL.Image.BICUBIC) + else: + assert batch_size == 1, 'variable img size must have batch size 1' except Exception: continue @@ -88,14 +91,14 @@ class PersonalizedBase(Dataset): if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) latent_sampling_method = "once" - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) elif latent_sampling_method == "deterministic": # Works only for DiagonalGaussianDistribution latent_dist.std = 0 latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) elif latent_sampling_method == "random": - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size) if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) @@ -151,6 +154,7 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) + self.img_shape = [entry.img_shape for entry in data] #self.emb_index = [entry.emb_index for entry in data] #print(self.latent_sample.device) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 45882ed68..9f96d0fda 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -451,8 +451,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ else: p.prompt = batch.cond_text[0] p.steps = 20 - p.width = training_width - p.height = training_height + p.width = batch.img_shape[0][0] + p.height = batch.img_shape[0][1] preview_text = p.prompt