diff --git a/scripts/embedding_editor.py b/scripts/embedding_editor.py index b000dde..b67c810 100644 --- a/scripts/embedding_editor.py +++ b/scripts/embedding_editor.py @@ -248,7 +248,11 @@ def save_embedding_weights(embedding_name, vector_num, *weights): checkpoint = sd_models.select_checkpoint() filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') - save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) + optimizer = torch.optim.AdamW([embedding.vec]) + + save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) + + print(f"Saved embedding to {filename}") def update_guidance_embeddings(text): try: