Skip to content

Commit

Permalink
support batch size in SD example
Browse files Browse the repository at this point in the history
  • Loading branch information
carlushuang committed Jul 29, 2023
1 parent 258a5d4 commit 8ff009c
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions examples/05_stable_diffusion/src/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def benchmark_unet(

latent_model_input_pt = torch.randn(batch_size, 4, height, width).cuda().half()
text_embeddings_pt = torch.randn(batch_size, 64, hidden_dim).cuda().half()
timesteps_pt = torch.Tensor([1, 1]).cuda().half()
timesteps_pt = torch.Tensor([1, 1]*(batch_size // 2)).cuda().half()

with autocast("cuda"):
pt_ys = pt_mod(
Expand Down Expand Up @@ -148,7 +148,7 @@ def benchmark_clip(
if tokenizer is None:
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_input = tokenizer(
["a photo of an astronaut riding a horse on mars"],
["a photo of an astronaut riding a horse on mars"] * batch_size,
padding="max_length",
max_length=seqlen,
truncation=True,
Expand All @@ -161,8 +161,9 @@ def benchmark_clip(
attention_mask = None

position_ids = torch.arange(seqlen).expand((batch_size, -1)).cuda()
pt_ys = pt_mod(input_ids, attention_mask, position_ids)
print("pt output:", pt_ys[0].shape)
if batch_size == 1:
pt_ys = pt_mod(input_ids, attention_mask, position_ids)
print("pt output:", pt_ys[0].shape)

# PT benchmark
if benchmark_pt:
Expand Down Expand Up @@ -292,7 +293,7 @@ def benchmark_vae(
@click.option("--benchmark-pt", type=bool, default=False, help="run pt benchmark")
@click.option("--profile-op", type=bool, default=False, help="profile model on op level")
def benchmark_diffusers(local_dir, batch_size, verify, benchmark_pt, profile_op):
assert batch_size == 1, "batch size must be 1 for submodule verification"
#assert batch_size == 1, "batch size must be 1 for submodule verification"
logging.getLogger().setLevel(logging.INFO)
np.random.seed(0)
torch.manual_seed(4896)
Expand Down

0 comments on commit 8ff009c

Please sign in to comment.