mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-18 12:20:11 +08:00
31 lines
894 B
Python
31 lines
894 B
Python
|
import torch
|
||
|
|
||
|
|
||
|
class TorchHijackForUnet:
    """
    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        """Forward any attribute not defined on this class to the ``torch`` module.

        ``cat`` never reaches this method: it is defined on the class, so normal
        attribute lookup resolves it before ``__getattr__`` is consulted.

        Raises:
            AttributeError: if ``torch`` has no attribute named ``item``.
        """
        try:
            return getattr(torch, item)
        except AttributeError:
            # Report the failure against this proxy object (as a normal attribute
            # miss would read) and drop the torch-internal exception context.
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) from None

    def cat(self, tensors, *args, **kwargs):
        """Concatenate ``tensors`` like ``torch.cat``, resizing mismatched 4-D pairs.

        When exactly two tensors are given, both are 4-D, and their last two
        dimensions differ, the first tensor is resized with nearest-neighbour
        interpolation to the last-two-dims size of the second before
        concatenation. Every other input is passed to ``torch.cat`` unchanged.

        Args:
            tensors: sequence of tensors to concatenate.
            *args, **kwargs: forwarded verbatim to ``torch.cat`` (e.g. ``dim``).

        Returns:
            The tensor produced by ``torch.cat``.
        """
        if len(tensors) == 2:
            a, b = tensors
            # interpolate() with a 2-element size only accepts 4-D input; for other
            # ranks, leave the tensors alone and let torch.cat handle (or reject) them.
            if a.dim() == 4 and b.dim() == 4 and a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)


# Module-level instance; presumably installed elsewhere as a drop-in replacement
# for the `torch` module inside UNet code — confirm against callers.
th = TorchHijackForUnet()
|