diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index fa48708e4..b47414f34 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -3,8 +3,10 @@ import numpy as np import PIL import torch from PIL import Image -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Dataset, DataLoader, Sampler from torchvision import transforms +from collections import defaultdict +from random import shuffle, choices import random import tqdm @@ -45,12 +47,12 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - assert batch_size == 1 or not varsize, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] self.shuffle_tags = shuffle_tags self.tag_drop_out = tag_drop_out + groups = defaultdict(list) print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): @@ -103,13 +105,14 @@ class PersonalizedBase(Dataset): if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): with devices.autocast(): entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) - + groups[image.size].append(len(self.dataset)) self.dataset.append(entry) del torchdata del latent_dist del latent_sample self.length = len(self.dataset) + self.groups = list(groups.values()) assert self.length > 0, "No images have been found in the dataset." self.batch_size = min(batch_size, self.length) self.gradient_step = min(gradient_step, self.length // self.batch_size) @@ -137,9 +140,34 @@ class PersonalizedBase(Dataset): entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) return entry +class GroupedBatchSampler(Sampler): + def __init__(self, data_source: PersonalizedBase, batch_size: int): + n = len(data_source) + self.groups = data_source.groups + self.len = n_batch = n // batch_size + expected = [len(g) / n * n_batch * batch_size for g in data_source.groups] + self.base = [int(e) // batch_size for e in expected] + self.n_rand_batches = nrb = n_batch - sum(self.base) + self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] + self.batch_size = batch_size + def __len__(self): + return self.len + def __iter__(self): + b = self.batch_size + for g in self.groups: + shuffle(g) + batches = [] + for g in self.groups: + batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) + for _ in range(self.n_rand_batches): + rand_group = choices(self.groups, self.probs)[0] + batches.append(choices(rand_group, k=b)) + shuffle(batches) + yield from batches + class PersonalizedDataLoader(DataLoader): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): - super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory) + super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) if latent_sampling_method == "random": self.collate_fn = collate_wrapper_random else: