From 0b143c1163a96b193a4e8512be9c5831c661a50d Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Thu, 3 Nov 2022 14:30:53 +0900 Subject: [PATCH] Separate .optim file from model --- modules/hypernetworks/hypernetwork.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8f74cdeae..63c25de8b 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -161,6 +161,7 @@ class Hypernetwork: def save(self, filename): state_dict = {} + optimizer_saved_dict = {} for k, v in self.layers.items(): state_dict[k] = (v[0].state_dict(), v[1].state_dict()) @@ -175,9 +176,10 @@ class Hypernetwork: state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name if self.optimizer_name is not None: - state_dict['optimizer_name'] = self.optimizer_name + optimizer_saved_dict['optimizer_name'] = self.optimizer_name if self.optimizer_state_dict: - state_dict['optimizer_state_dict'] = self.optimizer_state_dict + optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict + torch.save(optimizer_saved_dict, filename + '.optim') torch.save(state_dict, filename) @@ -198,9 +200,11 @@ class Hypernetwork: print(f"Layer norm is set to {self.add_layer_norm}") self.use_dropout = state_dict.get('use_dropout', False) print(f"Dropout usage is set to {self.use_dropout}") - self.optimizer_name = state_dict.get('optimizer_name', 'AdamW') + + optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} + self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') print(f"Optimizer name is {self.optimizer_name}") - self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None) + self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) if self.optimizer_state_dict: print("Loaded existing optimizer from checkpoint") else: