From 81133d4168ae0bae9bf8bf1a1d4983319a589112 Mon Sep 17 00:00:00 2001 From: Faber Date: Fri, 6 Jan 2023 03:38:37 +0700 Subject: [PATCH] allow loading embeddings from subdirectories --- .../textual_inversion/textual_inversion.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 24b430459..0a0590440 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -149,19 +149,20 @@ class EmbeddingDatabase: else: self.skipped_embeddings[name] = embedding - for fn in os.listdir(self.embeddings_dir): - try: - fullfn = os.path.join(self.embeddings_dir, fn) + for root, dirs, fns in os.walk(self.embeddings_dir): + for fn in fns: + try: + fullfn = os.path.join(root, fn) - if os.stat(fullfn).st_size == 0: + if os.stat(fullfn).st_size == 0: + continue + + process_file(fullfn, fn) + except Exception: + print(f"Error loading embedding {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) continue - process_file(fullfn, fn) - except Exception: - print(f"Error loading embedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - continue - print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if len(self.skipped_embeddings) > 0: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")