mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-03 20:22:56 +08:00
fix linter issues
This commit is contained in:
parent f8f4ff2bb8
commit 2d947175b9
@@ -212,7 +212,7 @@ class StableDiffusionModelHijack:
             model_embeddings = m.cond_stage_model.roberta.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
             m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
 
         elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
             model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
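For context, the hunk above relies on a wrapper pattern: the hijack swaps the text encoder's token-embedding module for one that performs the normal lookup and then patches in custom vectors (textual-inversion embeddings). Below is a minimal sketch of that pattern, not the real EmbeddingsWithFixes; the (offset, vector) fix format is an assumption for illustration.

import torch.nn as nn

# Simplified stand-in for EmbeddingsWithFixes: keep the original embedding
# module as `wrapped`, do the normal lookup, then overwrite the token slots
# that the hijack marked for replacement.
class EmbeddingsWithFixesSketch(nn.Module):
    def __init__(self, wrapped, hijack):
        super().__init__()
        self.wrapped = wrapped   # original token-embedding module
        self.hijack = hijack     # owner holding per-batch fixes

    def forward(self, input_ids):
        inputs_embeds = self.wrapped(input_ids)
        fixes = getattr(self.hijack, "fixes", None)  # [[(offset, vector), ...], ...]
        if not fixes:
            return inputs_embeds
        for batch_idx, batch_fixes in enumerate(fixes):
            for offset, vector in batch_fixes:
                emb_len = min(vector.shape[0], inputs_embeds.shape[1] - offset)
                inputs_embeds[batch_idx, offset:offset + emb_len] = vector[:emb_len]
        return inputs_embeds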
@@ -258,7 +258,7 @@ class StableDiffusionModelHijack:
 
             if hasattr(m, 'cond_stage_model'):
                 delattr(m, 'cond_stage_model')
 
         elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
             m.cond_stage_model = m.cond_stage_model.wrapped
 
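The undo path depends on the same convention: every wrapper keeps the original module on `.wrapped`, so restoring it is a single attribute assignment. A hedged sketch, assuming `m` is the model object from the hunk:

# Restore the original text encoder by unwrapping; mirrors the
# `m.cond_stage_model = m.cond_stage_model.wrapped` line above.
def undo_cond_stage_hijack(m):
    cond = getattr(m, "cond_stage_model", None)
    if cond is not None and hasattr(cond, "wrapped"):
        m.cond_stage_model = cond.wrapped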
@@ -95,8 +95,7 @@ def guess_model_config_from_state_dict(sd, filename):
     if diffusion_model_input.shape[1] == 8:
         return config_instruct_pix2pix
 
 
-    # import pdb; pdb.set_trace()
     if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
         if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
             return config_alt_diffusion_m18
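The surrounding function guesses which config matches a checkpoint by fingerprinting its state dict: the presence of an XLM-R text-encoder key marks an AltDiffusion model, and the width of the projection matrix separates the m18 variant. A minimal sketch under those assumptions; the returned strings stand in for the config paths used in the real function:

import torch

# Fingerprint a checkpoint's state dict by key presence and tensor shape,
# following the logic visible in the hunk above.
def guess_variant(sd: dict) -> str:
    if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight') is not None:
        transformation = sd.get('cond_stage_model.transformation.weight')
        if transformation is not None and transformation.size()[0] == 1024:
            return 'alt-diffusion-m18'   # stands in for config_alt_diffusion_m18
        return 'alt-diffusion'
    return 'stable-diffusion'

# Example: a fake state dict with an XLM-R text encoder and a 1024-dim projection.
sd = {
    'cond_stage_model.roberta.embeddings.word_embeddings.weight': torch.zeros(250002, 1024),
    'cond_stage_model.transformation.weight': torch.zeros(1024, 1024),
}
assert guess_variant(sd) == 'alt-diffusion-m18'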
@@ -1,4 +1,4 @@
-from transformers import BertPreTrainedModel,BertModel,BertConfig
+from transformers import BertPreTrainedModel,BertConfig
 import torch.nn as nn
 import torch
 from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
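This one-line change is a standard unused-import fix: `BertModel` was imported but never referenced, which pyflakes-based linters (flake8, ruff) report as F401. For illustration:

# Before: F401 'transformers.BertModel' imported but unused
# from transformers import BertPreTrainedModel,BertModel,BertConfig

# After: only the names the module actually uses
from transformers import BertPreTrainedModel, BertConfig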
@@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
     config_class = BertSeriesConfig
 
     def __init__(self, config=None, **kargs):
         # modify initialization for autoloading
         if config is None:
             config = XLMRobertaConfig()
             config.attention_probs_dropout_prob= 0.1
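The `__init__` shown above lets the model be constructed without an explicit config (useful for autoloading): when `config` is None it builds a default XLMRobertaConfig and overrides a dropout field. A minimal sketch of the same idiom; everything beyond the dropout line is an assumption:

from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig

# Default-config idiom from the hunk: construct a fallback config when the
# caller (or an autoloader) does not supply one.
class ConfigDefaultingModel:
    def __init__(self, config=None, **kwargs):
        if config is None:
            config = XLMRobertaConfig()
            config.attention_probs_dropout_prob = 0.1  # value taken from the diff
        self.config = config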
@@ -80,7 +80,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
         text["attention_mask"] = torch.tensor(
             text['attention_mask']).to(device)
         features = self(**text)
         return features['projection_state']
 
     def forward(
         self,
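These lines are the tail of the encode path: raw tokenizer output (plain Python lists) is converted to tensors, moved to the model's device, and the projected text features are returned. A hedged sketch of the whole path; the tokenizer call and function signature are assumptions standing in for the surrounding class:

import torch

# Tokenize, tensorize, move to device, run the model, return the projection.
def encode(model, tokenizer, prompts, device):
    text = tokenizer(prompts, padding=True)
    text = {
        "input_ids": torch.tensor(text["input_ids"]).to(device),
        "attention_mask": torch.tensor(text["attention_mask"]).to(device),
    }
    features = model(**text)
    return features["projection_state"]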
@@ -147,8 +147,8 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
             "hidden_states": outputs.hidden_states,
             "attentions": outputs.attentions,
         }
 
 
         # return {
         #     'pooler_output':pooler_output,
         #     'last_hidden_state':outputs.last_hidden_state,
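The dict being closed here is the forward() return value that encode() indexes into: alongside the standard transformer fields it carries the extra projection_state consumed upstream. A hedged sketch of that contract, with the field set inferred from this hunk and the encode() hunk above:

# forward() returns a plain dict rather than a ModelOutput; encode() reads
# features['projection_state'] from it. Other keys follow the diff.
def build_outputs(outputs, pooler_output, projection_state):
    return {
        "pooler_output": pooler_output,
        "last_hidden_state": outputs.last_hidden_state,
        "hidden_states": outputs.hidden_states,
        "attentions": outputs.attentions,
        "projection_state": projection_state,
    }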
@@ -161,4 +161,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
 
 class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
     base_model_prefix = 'roberta'
     config_class= RobertaSeriesConfig
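The final hunk shows the HuggingFace subclassing convention: the Roberta variant reuses the Bert-series implementation wholesale and only swaps the class attributes that drive checkpoint loading. `base_model_prefix` controls how state-dict key prefixes are resolved by from_pretrained, and `config_class` selects which config type is deserialized. A standalone sketch, with XLMRobertaConfig standing in for RobertaSeriesConfig:

from transformers import BertPreTrainedModel
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig

class BertSeriesSketch(BertPreTrainedModel):
    config_class = XLMRobertaConfig
    def __init__(self, config):
        super().__init__(config)

class RobertaSeriesSketch(BertSeriesSketch):
    base_model_prefix = 'roberta'    # prefix used for checkpoint key remapping
    config_class = XLMRobertaConfig  # stands in for RobertaSeriesConfig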