mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-20 05:10:15 +08:00
change json tensor key name
This commit is contained in:
parent
5d12ec82d3
commit
d0184b8f76
@ -19,15 +19,15 @@ import modules.textual_inversion.dataset
|
|||||||
class EmbeddingEncoder(json.JSONEncoder):
|
class EmbeddingEncoder(json.JSONEncoder):
|
||||||
def default(self, obj):
|
def default(self, obj):
|
||||||
if isinstance(obj, torch.Tensor):
|
if isinstance(obj, torch.Tensor):
|
||||||
return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()}
|
return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()}
|
||||||
return json.JSONEncoder.default(self, o)
|
return json.JSONEncoder.default(self, o)
|
||||||
|
|
||||||
class EmbeddingDecoder(json.JSONDecoder):
|
class EmbeddingDecoder(json.JSONDecoder):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
|
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
|
||||||
def object_hook(self, d):
|
def object_hook(self, d):
|
||||||
if 'EMBEDDINGTENSOR' in d:
|
if 'TORCHTENSOR' in d:
|
||||||
return torch.from_numpy(np.array(d['EMBEDDINGTENSOR']))
|
return torch.from_numpy(np.array(d['TORCHTENSOR']))
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def embeddingToB64(data):
|
def embeddingToB64(data):
|
||||||
|
Loading…
Reference in New Issue
Block a user