UniPC progress bar adjustment

This commit is contained in:
Sakura-Luna 2023-05-11 12:26:04 +08:00
parent 22bcc7be42
commit ae17e97898

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import math import math
from tqdm.auto import trange import tqdm
class NoiseScheduleVP: class NoiseScheduleVP:
@ -757,40 +757,44 @@ class UniPC:
vec_t = timesteps[0].expand((x.shape[0])) vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)] model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t] t_prev_list = [vec_t]
# Init the first `order` values by lower order multistep DPM-Solver. with tqdm.tqdm(total=steps) as pbar:
for init_order in range(1, order): # Init the first `order` values by lower order multistep DPM-Solver.
vec_t = timesteps[init_order].expand(x.shape[0]) for init_order in range(1, order):
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) vec_t = timesteps[init_order].expand(x.shape[0])
if model_x is None: x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
model_x = self.model_fn(x, vec_t)
if self.after_update is not None:
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
for step in trange(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
#print('this step order:', step_order)
if step == steps:
#print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
if self.after_update is not None:
self.after_update(x, model_x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
if model_x is None: if model_x is None:
model_x = self.model_fn(x, vec_t) model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x if self.after_update is not None:
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
pbar.update()
for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
#print('this step order:', step_order)
if step == steps:
#print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
if self.after_update is not None:
self.after_update(x, model_x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
pbar.update()
else: else:
raise NotImplementedError() raise NotImplementedError()
if denoise_to_zero: if denoise_to_zero: