Update sd_hijack_ddpm_v1.py

This commit is contained in:
Dalton 2024-03-19 14:45:07 -04:00 committed by GitHub
parent 61f321756f
commit 86276832e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,7 @@
# Some models such as LDSR require VQ to work correctly # Some models such as LDSR require VQ to work correctly
# The classes are suffixed with "V1" and added back to the "ldm.models.diffusion.ddpm" module # The classes are suffixed with "V1" and added back to the "ldm.models.diffusion.ddpm" module
import sys
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
@ -14,6 +15,8 @@ from contextlib import contextmanager
from functools import partial from functools import partial
from tqdm import tqdm from tqdm import tqdm
from torchvision.utils import make_grid from torchvision.utils import make_grid
import pytorch_lightning.utilities.rank_zero
sys.modules['pytorch_lightning.utilities.distributed'] = sys.modules['pytorch_lightning.utilities.rank_zero']
from pytorch_lightning.utilities.rank_zero import rank_zero_only from pytorch_lightning.utilities.rank_zero import rank_zero_only
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config