From 64a8f9d1b11c9f6427693f097f8b7eb8f5b4aec1 Mon Sep 17 00:00:00 2001
From: arrmansa <41120982+arrmansa@users.noreply.github.com>
Date: Mon, 30 Dec 2024 04:14:50 +0530
Subject: [PATCH] Update img2imgalt.py

Fix with documentation
---
 scripts/img2imgalt.py | 41 +++++++++++++++++++++++++++++++++++++----
 1 file changed, 37 insertions(+), 4 deletions(-)

diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py
index 1e833fa89..109c4a2ab 100644
--- a/scripts/img2imgalt.py
+++ b/scripts/img2imgalt.py
@@ -11,6 +11,10 @@ from modules import processing, shared, sd_samplers, sd_samplers_common
 import torch
 import k_diffusion as K
 
+# Debugging notes - the original method apply_model is being called for sd1.5 is in modules.sd_hijack_utils and is ldm.models.diffusion.ddpm.LatentDiffusion
+# For sdxl - OpenAIWrapper will be called, which will call the underlying diffusion_model
+
+
 def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
     x = p.init_latent
 
@@ -30,7 +34,13 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
 
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigmas[i] * s_in] * 2)
-        cond_in = torch.cat([uncond, cond])
+
+        if shared.sd_model.is_sdxl:
+            cond_tensor = cond['crossattn']
+            uncond_tensor = uncond['crossattn']
+            cond_in = torch.cat([uncond_tensor, cond_tensor])
+        else:
+            cond_in = torch.cat([uncond, cond])
 
         image_conditioning = torch.cat([p.image_conditioning] * 2)
         cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
@@ -38,7 +48,11 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
         c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
         t = dnw.sigma_to_t(sigma_in)
 
-        eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
+        if shared.sd_model.is_sdxl:
+            eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} )
+        else:
+            eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
+
         denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
 
         denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
@@ -64,6 +78,13 @@ Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "origina
 
 # Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
 def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
+    if shared.sd_model.is_sdxl:
+        cond_tensor = cond['crossattn']
+        uncond_tensor = uncond['crossattn']
+        cond_in = torch.cat([uncond_tensor, cond_tensor])
+    else:
+        cond_in = torch.cat([uncond, cond])
+
     x = p.init_latent
 
     s_in = x.new_ones([x.shape[0]])
@@ -82,7 +103,14 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
 
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
-        cond_in = torch.cat([uncond, cond])
+
+
+        if shared.sd_model.is_sdxl:
+            cond_tensor = cond['crossattn']
+            uncond_tensor = uncond['crossattn']
+            cond_in = torch.cat([uncond_tensor, cond_tensor])
+        else:
+            cond_in = torch.cat([uncond, cond])
 
         image_conditioning = torch.cat([p.image_conditioning] * 2)
         cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
@@ -94,7 +122,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
         else:
             t = dnw.sigma_to_t(sigma_in)
 
-        eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
+
+        if shared.sd_model.is_sdxl:
+            eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} )
+        else:
+            eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
+
         denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
 
         denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale