From 3728ef6d02bb6c02eca34bc03264f40cef7854eb Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 5 Sep 2022 16:48:26 -0700 Subject: [PATCH] better alphas --- examples/stable_diffusion.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 73619550c9..66f74ec315 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -515,6 +515,7 @@ class CLIPTextTransformer: class StableDiffusion: def __init__(self): + self.alphas_cumprod = Tensor.empty(1000) self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel()) self.first_stage_model = AutoencoderKL() self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer())) @@ -591,14 +592,20 @@ if __name__ == "__main__": #alphas = [0.9983, 0.6722, 0.2750, 0.0557] #alphas_prev = [0.9991499781608582, 0.9982960224151611, 0.6721514463424683, 0.27499905228614807] - alphas = [0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365, 0.0140] - alphas_prev = [1.0, 0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365] + #alphas = [0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365, 0.0140] + #alphas_prev = [1.0, 0.9983, 0.8930, 0.7521, 0.5888, 0.4229, 0.2750, 0.1598, 0.0819, 0.0365] + #timesteps = [1, 101, 201, 301, 401, 501, 601, 701, 801, 901] + timesteps = list(np.arange(1, 1000, 1000//20)) + print(timesteps) + alphas = [model.alphas_cumprod.numpy()[t] for t in timesteps] + alphas_prev = [1.0] + alphas[:-1] def get_x_prev_and_pred_x0(x, e_t, index): temperature = 1 a_t, a_prev = alphas[index], alphas_prev[index] sigma_t = 0 sqrt_one_minus_at = math.sqrt(1-a_t) + print(a_t, a_prev, sigma_t, sqrt_one_minus_at) pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t) @@ -614,7 +621,7 @@ if __name__ == "__main__": # is this the diffusion? #for index, timestep in tqdm(list(enumerate([1, 251, 501, 751]))[::-1]): - for index, timestep in tqdm(list(enumerate([1, 101, 201, 301, 401, 501, 601, 701, 801, 901]))[::-1]): + for index, timestep in tqdm(list(enumerate(timesteps))[::-1]): print(index, timestep) e_t = get_model_output(latent, timestep) #print(e_t.numpy())