From 8ff009cfd72b4415a1f1b1b73329df70cef606d7 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 29 Jul 2023 03:20:13 -0400 Subject: [PATCH] support batch size in SD example --- examples/05_stable_diffusion/src/benchmark.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/05_stable_diffusion/src/benchmark.py b/examples/05_stable_diffusion/src/benchmark.py index d8d88afae..8e4f47877 100644 --- a/examples/05_stable_diffusion/src/benchmark.py +++ b/examples/05_stable_diffusion/src/benchmark.py @@ -67,7 +67,7 @@ def benchmark_unet( latent_model_input_pt = torch.randn(batch_size, 4, height, width).cuda().half() text_embeddings_pt = torch.randn(batch_size, 64, hidden_dim).cuda().half() - timesteps_pt = torch.Tensor([1, 1]).cuda().half() + timesteps_pt = torch.Tensor([1, 1]*(batch_size // 2)).cuda().half() with autocast("cuda"): pt_ys = pt_mod( @@ -148,7 +148,7 @@ def benchmark_clip( if tokenizer is None: tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") text_input = tokenizer( - ["a photo of an astronaut riding a horse on mars"], + ["a photo of an astronaut riding a horse on mars"] * batch_size, padding="max_length", max_length=seqlen, truncation=True, @@ -161,8 +161,9 @@ def benchmark_clip( attention_mask = None position_ids = torch.arange(seqlen).expand((batch_size, -1)).cuda() - pt_ys = pt_mod(input_ids, attention_mask, position_ids) - print("pt output:", pt_ys[0].shape) + if batch_size == 1: + pt_ys = pt_mod(input_ids, attention_mask, position_ids) + print("pt output:", pt_ys[0].shape) # PT benchmark if benchmark_pt: @@ -292,7 +293,7 @@ def benchmark_vae( @click.option("--benchmark-pt", type=bool, default=False, help="run pt benchmark") @click.option("--profile-op", type=bool, default=False, help="profile model on op level") def benchmark_diffusers(local_dir, batch_size, verify, benchmark_pt, profile_op): - assert batch_size == 1, "batch size must be 1 for submodule verification" + #assert batch_size == 1, "batch size must be 1 for submodule verification" logging.getLogger().setLevel(logging.INFO) np.random.seed(0) torch.manual_seed(4896)