sort self.word_embeddings without instantiating a new dict

This commit is contained in:
Brad Smith 2023-04-13 23:12:33 -04:00
parent 27b9ec60e4
commit dab5002c59
No known key found for this signature in database
GPG Key ID: CDABCFFBBD8DA710

View File

@ -2,7 +2,7 @@ import os
import sys import sys
import traceback import traceback
import inspect import inspect
from collections import namedtuple, OrderedDict from collections import namedtuple
import torch import torch
import tqdm import tqdm
@ -108,7 +108,7 @@ class DirWithTextualInversionEmbeddings:
class EmbeddingDatabase: class EmbeddingDatabase:
def __init__(self): def __init__(self):
self.ids_lookup = {} self.ids_lookup = {}
self.word_embeddings = OrderedDict() self.word_embeddings = {}
self.skipped_embeddings = {} self.skipped_embeddings = {}
self.expected_shape = -1 self.expected_shape = -1
self.embedding_dirs = {} self.embedding_dirs = {}
@ -234,7 +234,10 @@ class EmbeddingDatabase:
embdir.update() embdir.update()
# re-sort word_embeddings because load_from_dir may not load in alphabetic order. # re-sort word_embeddings because load_from_dir may not load in alphabetic order.
self.word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())} # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
self.word_embeddings.clear()
self.word_embeddings.update(sorted_word_embeddings)
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
if self.previously_displayed_embeddings != displayed_embeddings: if self.previously_displayed_embeddings != displayed_embeddings: