
Commit

freeze res module
cloneofsimo committed Jun 6, 2024
1 parent 85c854d commit 8913d38
Showing 3 changed files with 35 additions and 21 deletions.
advanced/main_t2i_highres.py: 29 additions & 15 deletions (44 changes)
@@ -130,6 +130,7 @@ def __init__(self, model, ln=True):
         self.model = model
         self.ln = ln
         self.stratified = False
+        self.t_transform = lambda t : (math.sqrt(3) * t / (1 + (math.sqrt(3) - 1) *t))
 
     def forward(self, x, cond, randomly_augment_x_latent=False):
 
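The added t_transform has the same functional form as the resolution-dependent timestep shift used for rectified flows at larger resolutions, t' = s * t / (1 + (s - 1) * t), here with s = sqrt(3): it maps [0, 1] onto itself while pushing sampled timesteps toward the noisier end. In this commit it is only defined; the call site further down stays commented out. A standalone sketch of its behavior (not code from the repo):

import math

# Same form as the added lambda, with shift s = sqrt(3)
t_transform = lambda t: (math.sqrt(3) * t) / (1 + (math.sqrt(3) - 1) * t)

for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"t={t:.2f} -> {t_transform(t):.3f}")
# t=0.00 -> 0.000
# t=0.25 -> 0.366
# t=0.50 -> 0.634
# t=0.75 -> 0.839
# t=1.00 -> 1.000

Since the loss below mixes data and noise as zt = (1 - t) * x + t * z1, a larger transformed t means a noisier training example.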
@@ -151,18 +152,19 @@ def forward(self, x, cond, randomly_augment_x_latent=False):
             # and also, they are correct. This comes with the tradeoff that in worst case we drop 66% of the batches.
 
             new_b = int(0.18 * b * h * w / (new_h * new_w)) * 2
-            new_b = min(new_b, 1)
+            new_b = max(new_b, 1)
             x = x[
                 :new_b,
                 :,
                 h // 2 - new_h // 2 : h // 2 + new_h // 2,
                 w // 2 - new_w // 2 : w // 2 + new_w // 2,
             ]
         else:
-            new_b = min(int(b * 0.18) * 2, 1)
+            new_b = max(int(b * 0.18) * 2, 1)
             x = x[:new_b]
 
         b = x.size(0)
+        #print(x.size())
         if self.ln:
             if self.stratified:
                 # stratified sampling of normals
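The two min-to-max changes fix the clamp direction: min(new_b, 1) collapses every batch to a single sample, whereas max(new_b, 1) only guarantees that at least one sample survives. With new_b = int(0.18 * b * h * w / (new_h * new_w)) * 2, the kept token count new_b * new_h * new_w stays near 0.36 * b * h * w regardless of crop size (up to rounding and the cap at the original batch size), which matches the existing comment about dropping up to roughly 66% of the batch in the worst case. A small arithmetic sketch with hypothetical shapes (not taken from the repo):

b, h, w = 64, 128, 128                       # hypothetical latent batch and spatial size
for new_h, new_w in [(64, 64), (96, 96), (128, 128)]:
    new_b = int(0.18 * b * h * w / (new_h * new_w)) * 2
    new_b = max(new_b, 1)                    # with min(new_b, 1) this would always be 1
    print(new_h, new_w, min(new_b, b))       # 64x64 -> 64 (capped), 96x96 -> 40, 128x128 -> 22

Smaller crops keep more samples, so per-step compute stays roughly constant; the slicing x[:new_b] silently caps the count at the original batch size, mirrored here with min(new_b, b).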
@@ -178,6 +180,7 @@ def forward(self, x, cond, randomly_augment_x_latent=False):
         else:
             t = torch.rand((b,)).to(x.device)
         texp = t.view([b, *([1] * len(x.shape[1:]))])
+        #texp = self.t_transform(texp)
         z1 = torch.randn_like(x)
         zt = (1 - texp) * x + texp * z1
 
@@ -443,17 +446,18 @@ def main(
         ),
         True,
     ).cuda()
-    statedict = torch.load(
-        "/home/ubuntu/ckpts_36L_2/model_102401/ema1.pt",
-        map_location="cpu",
-    )
-    # remove model.layers.23.modC.1.weight
-    # statedict.pop("model.layers.31.modC.1.weight")
+    if True:
+        statedict = torch.load(
+            "/home/ubuntu/ckpts_36L_2/model_102401/ema1.pt",
+            map_location="cpu",
+        )
+        # remove model.layers.23.modC.1.weight
+        # statedict.pop("model.layers.31.modC.1.weight")
 
-    rf.load_state_dict(
-        statedict,
-        strict=False,
-    )
+        rf.load_state_dict(
+            statedict,
+            strict=False,
+        )
     if resize_pe_at_initialization:
         rf.model.extend_pe((16, 16), (vaeres // 2, vaeres // 2))
 
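Loading with strict=False lets the 36-layer checkpoint apply even if some keys are missing or extra (the commented-out modC pop hints at that), and extend_pe((16, 16), (vaeres // 2, vaeres // 2)) then grows the positional-encoding grid from the pretraining latent size to the high-resolution one. The extend_pe implementation itself is not part of this diff; the following is only a sketch of one common approach, 2-D interpolation of a learned table assumed to be stored as (1, H * W, D):

import torch
import torch.nn.functional as F

def extend_pe_sketch(pe, old_hw, new_hw):
    # Illustrative only; the repo's MMDiT.extend_pe may use a different scheme.
    old_h, old_w = old_hw
    new_h, new_w = new_hw
    d = pe.shape[-1]
    grid = pe.reshape(1, old_h, old_w, d).permute(0, 3, 1, 2)        # 1, D, H, W
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, d)

pe = torch.randn(1, 16 * 16, 2560)
print(extend_pe_sketch(pe, (16, 16), (48, 48)).shape)                # torch.Size([1, 2304, 2560])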
@@ -476,7 +480,7 @@ def main(
     # barrier
     torch.distributed.barrier()
 
-    os.environ["LOCAL_WORLD_SIZE"] = str(8)
+    os.environ["LOCAL_WORLD_SIZE"] = str(min(8, int(os.environ.get("WORLD_SIZE"))))
     # WORLD_SIZE: Total number of processes to launch across all nodes.
     # LOCAL_WORLD_SIZE: Total number of processes to launch for each node.
     # RANK: Rank of the current process, which is the range between 0 to WORLD_SIZE - 1.
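Hard-coding LOCAL_WORLD_SIZE to 8 assumed a full 8-GPU node; clamping with min(8, WORLD_SIZE) keeps the per-node process count correct when the launcher starts fewer ranks (the int() call still expects the launcher to have set WORLD_SIZE). A tiny sketch of the clamp with hypothetical launcher values:

import os

for world_size in (1, 4, 8, 16):                     # hypothetical launch configurations
    os.environ["WORLD_SIZE"] = str(world_size)
    local = min(8, int(os.environ.get("WORLD_SIZE")))
    print(world_size, "->", local)                   # 1 -> 1, 4 -> 4, 8 -> 8, 16 -> 8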
@@ -585,6 +589,16 @@ def dequantize_t5(tensor):
     optimizer_grouped_parameters = []
     final_optimizer_settings = {}
 
+    # requires grad for first 2 and last 2 layer
+    for n, p in rf.named_parameters():
+        if "layers" in n:
+            if any(layername in n for layername in ["layers.0.", "layers.1.", "layers.34.", "layers.35."]):
+                p.requires_grad = True
+            else:
+                p.requires_grad = False
+        else:
+            p.requires_grad = True
+
     for n, p in rf.named_parameters():
         group_parameters = {}
         if p.requires_grad:
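This block is what the commit message refers to: inside the 36-layer stack only layers 0, 1, 34, and 35 keep requires_grad=True, while every parameter whose name does not contain "layers" (positional encoding, conditioning projections, and so on) stays trainable. The trailing dots in "layers.1." are what keep it from also matching "layers.10." through "layers.19.". Because the grouped-parameter loop just below only collects parameters with requires_grad set, the frozen blocks never reach the optimizer. A standalone sketch of the same pattern on a toy module (module names are illustrative, not the repo's):

import torch.nn as nn

model = nn.ModuleDict({
    "embed": nn.Linear(16, 16),
    "layers": nn.ModuleList([nn.Linear(16, 16) for _ in range(36)]),
})

trainable = ["layers.0.", "layers.1.", "layers.34.", "layers.35."]
for n, p in model.named_parameters():
    if "layers" in n:
        p.requires_grad = any(tag in n for tag in trainable)
    else:
        p.requires_grad = True

n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f"trainable: {n_train} / {n_total}")           # 4 of 36 layers plus the embedding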
@@ -630,7 +644,7 @@ def dequantize_t5(tensor):
     AdamOptimizer = torch.optim.AdamW
 
     optimizer = AdamOptimizer(
-        rf.parameters(), lr=learning_rate * (32 / hidden_dim), betas=(0.9, 0.95)
+        optimizer_grouped_parameters, betas=(0.9, 0.95)
     )
 
     lr_scheduler = get_scheduler(
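AdamW now receives optimizer_grouped_parameters, so each group can carry its own settings (presumably the per-group values assembled in the loop above, which this diff does not expand) instead of one global lr scaled by 32 / hidden_dim, and frozen parameters are simply never added. A minimal sketch of how torch.optim.AdamW consumes such groups (the lr values are placeholders, not the repo's):

import torch

w1 = torch.nn.Parameter(torch.randn(8, 8))
w2 = torch.nn.Parameter(torch.randn(8))

optimizer_grouped_parameters = [
    {"params": [w1], "lr": 1e-4, "weight_decay": 0.1},    # hypothetical per-group settings
    {"params": [w2], "lr": 3e-4, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, betas=(0.9, 0.95))
print([g["lr"] for g in optimizer.param_groups])           # [0.0001, 0.0003]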
@@ -772,7 +786,7 @@ def dequantize_t5(tensor):
                    f"norm: {norm}, loss: {loss.item()}, global_step: {global_step}"
                )
 
-            if global_step % 4096 == 1:
+            if global_step % 4096 == 0:
 
                os.makedirs(f"{save_dir}/model_{global_step}", exist_ok=True)
                save_zero_three_model(
advanced/mmdit.py: 3 additions & 3 deletions (6 changes)
@@ -373,8 +373,8 @@ def forward(self, x, t, conds, **kwargs):
         x = x + self.positional_encoding[:, pe_indexes]
 
         # process conditions for MMDiT Blocks
-        c_seq = conds["c_seq"][:b]  # B, T_c, D_c
-        t = t[:b]
+        c_seq = conds["c_seq"][0:b]  # B, T_c, D_c
+        t = t[0:b]
         # c_vec = conds["c_vec"]  # B, D_gc
         c = self.cond_seq_linear(c_seq)  # B, T_c, D
         c = torch.cat([self.register_tokens.repeat(c.size(0), 1, 1), c], dim=1)
@@ -461,7 +461,7 @@ def forward(self, x, t, conds, **kwargs):
 if __name__ == "__main__":
     model = MMDiT(max_seq=32 * 32)
     model.extend_pe((32, 32), (64, 64))
-    x = torch.randn(2, 4, 20, 48)
+    x = torch.randn(1, 4, 20, 48)
     t = torch.randn(8)
     conds = {"c_seq": torch.randn(8, 32, 2048)}
     out = model(x, t, conds)
advanced/run_multi_node_resize.sh: 3 additions & 3 deletions (6 changes)
@@ -24,13 +24,13 @@ done
 
 deepspeed --hostfile=./hostfiles \
     main_t2i_highres.py \
-    --learning_rate 0.0366 \
+    --learning_rate 0.006 \
     --hidden_dim 2560 \
     --n_layers 36 \
     --run_name node-2-highres \
-    --save_dir "/home/ubuntu/ckpts_36L_2_highres" \
+    --save_dir "/home/ubuntu/ckpts_36L_2_highres_freezemost" \
     --num_train_epochs 200 \
-    --train_batch_size 384 \
+    --train_batch_size 256 \
     --per_device_train_batch_size 4 \
     --train_dir "/home/ubuntu/laionpop" \
     --seed 3 \
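The launch script tracks the optimizer change: --learning_rate drops from 0.0366 to 0.006, presumably because the removed global scaling lr = learning_rate * (32 / hidden_dim) is now handled by the per-group settings; the save directory gets a _freezemost suffix, and the global batch shrinks from 384 to 256. For reference, the effective rate the removed AdamW call would have used (values from the old script and the model config):

learning_rate, hidden_dim = 0.0366, 2560
print(learning_rate * (32 / hidden_dim))                   # about 4.575e-4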
