diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index ccdb15589d..3cdce2628c 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -409,19 +409,30 @@ class CLIPAttention: def _shape(self, tensor, seq_len: int, bsz: int): return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0,2,1,3) - def __call__(self, hidden_states): + def __call__(self, hidden_states, causal_attention_mask): bsz, tgt_len, embed_dim = hidden_states.shape + query_states = self.q_proj(hidden_states) * self.scale key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + #print("ATTN", query_states.numpy()) + #print(hidden_states.shape, query_states.shape, key_states.shape, value_states.shape) + proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) + src_len = key_states.shape[1] value_states = value_states.reshape(*proj_shape) + #print(query_states.shape, key_states.shape) attn_weights = query_states @ key_states.permute(0,2,1) + + #print(attn_weights.shape, causal_attention_mask.shape) + attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask + attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.softmax() attn_output = attn_weights @ value_states @@ -440,10 +451,10 @@ class CLIPEncoderLayer: self.mlp = CLIPMLP() self.layer_norm2 = Normalize(768, num_groups=None) - def __call__(self, hidden_states): + def __call__(self, hidden_states, causal_attention_mask): residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.self_attn(hidden_states) + hidden_states = self.self_attn(hidden_states, causal_attention_mask) hidden_states = residual + hidden_states residual = hidden_states @@ -457,8 +468,13 @@ class CLIPEncoder: def __init__(self): self.layers = [CLIPEncoderLayer() for i in range(12)] - def __call__(self, hidden_states): - return hidden_states.sequential(self.layers) + def __call__(self, hidden_states, causal_attention_mask): + for i,l in enumerate(self.layers): + #if i == 2: + # print(hidden_states.numpy()) + # break + hidden_states = l(hidden_states, causal_attention_mask) + return hidden_states class CLIPTextEmbeddings: def __init__(self): @@ -484,14 +500,16 @@ class CLIPTextTransformer: def __call__(self, input_ids): x = self.embeddings(input_ids, list(range(len(input_ids)))) - x = self.encoder(x) + print(x.numpy()) + causal_attention_mask = np.triu(np.ones((1,1,77,77), dtype=np.float32) * -np.inf, k=1) + x = self.encoder(x, Tensor(causal_attention_mask, device=x.device)) return self.final_layer_norm(x) class StableDiffusion: def __init__(self): - self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel()) + #self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel()) #self.first_stage_model = AutoencoderKL() - #self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer())) + self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer())) def __call__(self, x, timesteps, context): return self.model.diffusion_model(x, timesteps, context) @@ -537,12 +555,41 @@ if __name__ == "__main__": assert w.shape == v.shape w.assign(v.astype(np.float32)) + # "a horse sized cat eating a bagel" + phrase = [49406, 320, 4558, 9832, 2368, 4371, 320, 28777, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407] + + """ + # load apple latent space + nz = Tensor(np.load("datasets/stable_diffusion_apple.npy")) + + # run unet (without context) + timesteps = Tensor([32]) + context = Tensor.zeros(1, 77, 768) + nz = model(nz, timesteps, context) + + # upsample latent space to image with autoencoder + x = model.first_stage_model.post_quant_conv(nz) + x = model.first_stage_model.decoder(x) + + # make image correct size + x = x.reshape(3,512,512).permute(1,2,0) + dat = (x.detach().numpy().clip(0, 1)*255).astype(np.uint8) + print(dat.shape) + + # save image + from PIL import Image + im = Image.fromarray(dat) + im.save("/tmp/rendered.png") + exit(0) + """ + """ outs = model.cond_stage_model.transformer.text_model([1,2,3]) print(outs.numpy()) print(outs.numpy().shape) """ + """ from ldm.modules.diffusionmodules.openaimodel import UNetModel tmodel = UNetModel( image_size = 32, @@ -559,10 +606,11 @@ if __name__ == "__main__": use_checkpoint = True, legacy = False) prefix = "model.diffusion_model." + """ - #from ldm.modules.encoders.modules import FrozenCLIPEmbedder - #tmodel = FrozenCLIPEmbedder() - #prefix = "cond_stage_model." + from ldm.modules.encoders.modules import FrozenCLIPEmbedder + tmodel = FrozenCLIPEmbedder() + prefix = "cond_stage_model." #from ldm.models.autoencoder import AutoencoderKL #tmodel = AutoencoderKL( @@ -590,22 +638,31 @@ if __name__ == "__main__": sd[k[len(prefix):]] = dat[k] print("loading", len(sd)) tmodel.load_state_dict(sd, strict=True) + tmodel = tmodel.cuda() - # load apple latent space - nz = Tensor(np.load("datasets/stable_diffusion_apple.npy")) + ret = tmodel("a horse sized cat eating a bagel") + print(ret) + + re = model.cond_stage_model.transformer.text_model(phrase) + print(re.numpy()) + + exit(0) # run one pass of unet tnz = torch.Tensor(nz.numpy()) - ttimesteps = torch.Tensor([0]) - tcontext = torch.zeros(1, 77, 768) + timesteps = Tensor([10]) + context = Tensor.uniform(1, 77, 768) + + ttimesteps = torch.Tensor(timesteps.numpy()) + tcontext = torch.Tensor(context.numpy()) tnz = tmodel(tnz, ttimesteps, tcontext) - timesteps = Tensor([0]) - context = Tensor.zeros(1, 77, 768) nz = model(nz, timesteps, context) print(tnz) print(nz.numpy()) + print("match", np.mean((tnz.detach().numpy() - nz.numpy())**2)) + exit(0)