diff --git a/modules/initialize_util.py b/modules/initialize_util.py index b6767138d..63cea137f 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -24,6 +24,13 @@ def fix_torch_version(): torch.__long_version__ = torch.__version__ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) +def fix_pytorch_lightning(): + import pytorch_lightning + # Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache + if 'pytorch_lightning.utilities.distributed' not in sys.modules: + # Lets the user know that the library was not found and then will set it to pytorch_lightning.utilities.rank_zero + print(f"Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero") + sys.modules["pytorch_lightning.utilities.distributed"] = pytorch_lightning.utilities.rank_zero def fix_asyncio_event_loop_policy(): """