diff --git a/advanced/main_t2i_highres.py b/advanced/main_t2i_highres.py
index 5ce5460..918fd71 100644
--- a/advanced/main_t2i_highres.py
+++ b/advanced/main_t2i_highres.py
@@ -130,6 +130,7 @@ def __init__(self, model, ln=True):
         self.model = model
         self.ln = ln
         self.stratified = False
+        self.t_transform = lambda t : (math.sqrt(3) * t / (1 + (math.sqrt(3) - 1) *t))
 
     def forward(self, x, cond, randomly_augment_x_latent=False):
 
@@ -151,7 +152,7 @@ def forward(self, x, cond, randomly_augment_x_latent=False):
 
             # and also, they are correct. This comes with the tradeoff that in worst case we drop 66% of the batches.
             new_b = int(0.18 * b * h * w / (new_h * new_w)) * 2
-            new_b = min(new_b, 1)
+            new_b = max(new_b, 1)
             x = x[
                 :new_b,
                 :,
@@ -159,10 +160,11 @@ def forward(self, x, cond, randomly_augment_x_latent=False):
                 w // 2 - new_w // 2 : w // 2 + new_w // 2,
             ]
         else:
-            new_b = min(int(b * 0.18) * 2, 1)
+            new_b = max(int(b * 0.18) * 2, 1)
             x = x[:new_b]
 
         b = x.size(0)
+        #print(x.size())
         if self.ln:
             if self.stratified:
                 # stratified sampling of normals
@@ -178,6 +180,7 @@ def forward(self, x, cond, randomly_augment_x_latent=False):
             else:
                 t = torch.rand((b,)).to(x.device)
             texp = t.view([b, *([1] * len(x.shape[1:]))])
+            #texp = self.t_transform(texp)
 
         z1 = torch.randn_like(x)
         zt = (1 - texp) * x + texp * z1
@@ -443,17 +446,18 @@ def main(
         ),
         True,
     ).cuda()
-    statedict = torch.load(
-        "/home/ubuntu/ckpts_36L_2/model_102401/ema1.pt",
-        map_location="cpu",
-    )
-    # remove model.layers.23.modC.1.weight
-    # statedict.pop("model.layers.31.modC.1.weight")
+    if True:
+        statedict = torch.load(
+            "/home/ubuntu/ckpts_36L_2/model_102401/ema1.pt",
+            map_location="cpu",
+        )
+        # remove model.layers.23.modC.1.weight
+        # statedict.pop("model.layers.31.modC.1.weight")
 
-    rf.load_state_dict(
-        statedict,
-        strict=False,
-    )
+        rf.load_state_dict(
+            statedict,
+            strict=False,
+        )
 
     if resize_pe_at_initialization:
         rf.model.extend_pe((16, 16), (vaeres // 2, vaeres // 2))
@@ -476,7 +480,7 @@ def main(
 
     # barrier
     torch.distributed.barrier()
-    os.environ["LOCAL_WORLD_SIZE"] = str(8)
+    os.environ["LOCAL_WORLD_SIZE"] = str(min(8, int(os.environ.get("WORLD_SIZE"))))
    # WORLD_SIZE: Total number of processes to launch across all nodes.
    # LOCAL_WORLD_SIZE: Total number of processes to launch for each node.
    # RANK: Rank of the current process, which is the range between 0 to WORLD_SIZE - 1.
@@ -585,6 +589,16 @@ def dequantize_t5(tensor):
     optimizer_grouped_parameters = []
     final_optimizer_settings = {}
 
+    # requires grad for first 2 and last 2 layer
+    for n, p in rf.named_parameters():
+        if "layers" in n:
+            if any(layername in n for layername in ["layers.0.", "layers.1.", "layers.34.", "layers.35."]):
+                p.requires_grad = True
+            else:
+                p.requires_grad = False
+        else:
+            p.requires_grad = True
+
     for n, p in rf.named_parameters():
         group_parameters = {}
         if p.requires_grad:
@@ -630,7 +644,7 @@ def dequantize_t5(tensor):
     AdamOptimizer = torch.optim.AdamW
 
     optimizer = AdamOptimizer(
-        rf.parameters(), lr=learning_rate * (32 / hidden_dim), betas=(0.9, 0.95)
+        optimizer_grouped_parameters, betas=(0.9, 0.95)
     )
 
     lr_scheduler = get_scheduler(
@@ -772,7 +786,7 @@ def dequantize_t5(tensor):
                 f"norm: {norm}, loss: {loss.item()}, global_step: {global_step}"
             )
 
-            if global_step % 4096 == 1:
+            if global_step % 4096 == 0:
                 os.makedirs(f"{save_dir}/model_{global_step}", exist_ok=True)
 
                 save_zero_three_model(
diff --git a/advanced/mmdit.py b/advanced/mmdit.py
index 99ce416..25278fa 100644
--- a/advanced/mmdit.py
+++ b/advanced/mmdit.py
@@ -373,8 +373,8 @@ def forward(self, x, t, conds, **kwargs):
         x = x + self.positional_encoding[:, pe_indexes]
 
         # process conditions for MMDiT Blocks
-        c_seq = conds["c_seq"][:b]  # B, T_c, D_c
-        t = t[:b]
+        c_seq = conds["c_seq"][0:b]  # B, T_c, D_c
+        t = t[0:b]
         # c_vec = conds["c_vec"]  # B, D_gc
         c = self.cond_seq_linear(c_seq)  # B, T_c, D
         c = torch.cat([self.register_tokens.repeat(c.size(0), 1, 1), c], dim=1)
@@ -461,7 +461,7 @@ def forward(self, x, t, conds, **kwargs):
 if __name__ == "__main__":
     model = MMDiT(max_seq=32 * 32)
     model.extend_pe((32, 32), (64, 64))
-    x = torch.randn(2, 4, 20, 48)
+    x = torch.randn(1, 4, 20, 48)
     t = torch.randn(8)
     conds = {"c_seq": torch.randn(8, 32, 2048)}
     out = model(x, t, conds)
diff --git a/advanced/run_multi_node_resize.sh b/advanced/run_multi_node_resize.sh
index 07b1cdd..9f2c35a 100644
--- a/advanced/run_multi_node_resize.sh
+++ b/advanced/run_multi_node_resize.sh
@@ -24,13 +24,13 @@ done
 
 deepspeed --hostfile=./hostfiles \
     main_t2i_highres.py \
-    --learning_rate 0.0366 \
+    --learning_rate 0.006 \
     --hidden_dim 2560 \
     --n_layers 36 \
     --run_name node-2-highres \
-    --save_dir "/home/ubuntu/ckpts_36L_2_highres" \
+    --save_dir "/home/ubuntu/ckpts_36L_2_highres_freezemost" \
     --num_train_epochs 200 \
-    --train_batch_size 384 \
+    --train_batch_size 256 \
     --per_device_train_batch_size 4 \
     --train_dir "/home/ubuntu/laionpop" \
     --seed 3 \
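
Note (not part of the patch): the freeze-most change in main_t2i_highres.py is split across two hunks, the requires_grad loop and the switch from rf.parameters() to optimizer_grouped_parameters in the AdamW call. Below is a minimal, self-contained sketch of how the two pieces are presumably meant to compose. The helper name freeze_and_group and the toy model are illustrative only, and the 32 / hidden_dim learning-rate scaling is carried over from the lr term the patch removes from the AdamW call; whether the real optimizer_grouped_parameters applies the same per-group scaling is an assumption.

import torch
import torch.nn as nn


def freeze_and_group(model, lr, hidden_dim, weight_decay=0.0):
    # Freeze every transformer block except the first two and the last two,
    # mirroring the ["layers.0.", "layers.1.", "layers.34.", "layers.35."] filter.
    trainable_blocks = ("layers.0.", "layers.1.", "layers.34.", "layers.35.")
    for name, param in model.named_parameters():
        if "layers" in name:
            param.requires_grad = any(block in name for block in trainable_blocks)
        else:
            # Embeddings, projections, etc. outside the block stack stay trainable.
            param.requires_grad = True

    # One parameter group per trainable tensor, standing in for
    # optimizer_grouped_parameters; only requires_grad params are included,
    # so frozen weights never reach the optimizer.
    groups = [
        {"params": [param], "lr": lr * (32 / hidden_dim), "weight_decay": weight_decay}
        for _, param in model.named_parameters()
        if param.requires_grad
    ]
    return torch.optim.AdamW(groups, betas=(0.9, 0.95))


if __name__ == "__main__":
    # Toy stand-in for the 36-layer model so the sketch runs without MMDiT/RF.
    toy = nn.ModuleDict({"layers": nn.Sequential(*[nn.Linear(8, 8) for _ in range(36)])})
    optimizer = freeze_and_group(toy, lr=0.006, hidden_dim=2560)
    n_trainable = sum(p.numel() for g in optimizer.param_groups for p in g["params"])
    print(f"trainable params handed to AdamW: {n_trainable}")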