Do not hardcode num_blocks
nitinkedia7 committed May 13, 2024
1 parent b096c8c commit aa1e9cc
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions simulator/profiling/attention/attention_wrapper.py
@@ -1,4 +1,4 @@
-from math import ceil
+from math import ceil, floor
from typing import List

import numpy as np
@@ -63,14 +63,18 @@ def __init__(
self._device,
)
self._max_blocks_per_sequence = ceil(max_model_len / self._block_size)
+        # We create (big) KV tensors and reuse them
+        element_size = torch.randn(1, dtype=self._dtype).element_size()
+        block_memory_size = 2 * self._block_size * self._n_worker_kv_heads * self._head_dim * element_size
+        self.total_num_blocks = floor((torch.cuda.mem_get_info()[1] * 0.9) / (block_memory_size * model_config.num_layers))
+        self.kv_cache = get_attention_wrapper().get_cache_block(
+            self.total_num_blocks, dtype=self._dtype, device=self._device
+        )

def _get_input_tensors(
self,
attention_input: AttentionInput,
):
-        total_num_blocks = max(
-            10000, 1 + attention_input.batch_size * self._max_blocks_per_sequence
-        )
num_tokens_per_seq = (
attention_input.prefill_chunk_size if attention_input.is_prefill else 1
)
@@ -93,11 +97,6 @@ def _get_input_tensors(
dtype=self._dtype,
device=self._device,
)
-        # We create (big) KV tensors every time.
-        # A better solution would be to create them once and reuse them.
-        kv_cache = get_attention_wrapper().get_cache_block(
-            total_num_blocks, dtype=self._dtype, device=self._device
-        )
# Create SequenceMetadataProxy objects corresponding to AttentionInput
seq_metadata_list: List[SequenceMetadataProxy] = []
for _ in range(attention_input.batch_size):
@@ -108,11 +107,11 @@ def _get_input_tensors(
total_len=num_tokens_per_seq + attention_input.kv_cache_size,
processed_len=attention_input.kv_cache_size,
block_table=np.random.default_rng()
-                .integers(low=0, high=total_num_blocks - 1, size=num_blocks)
+                .integers(low=0, high=self.total_num_blocks - 1, size=num_blocks)
.tolist(),
)
seq_metadata_list.append(seq_metadata)
-        return seq_metadata_list, query, key, value, kv_cache
+        return seq_metadata_list, query, key, value, self.kv_cache

@torch.inference_mode()
def profile(
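For context on the diff above: previously `_get_input_tensors` allocated a fresh KV cache on every profiling call and clamped the block count with a hardcoded `max(10000, ...)`; after this commit the block count is computed once in `__init__` from the device's reported memory, and a single KV cache is reused across calls. Below is a minimal standalone sketch of the sizing arithmetic. The function name `compute_total_num_blocks` and the `gpu_memory_utilization` parameter are illustrative, not part of this repository; the commit inlines the same math in `__init__` with a fixed 0.9 factor.

    from math import floor

    import torch


    def compute_total_num_blocks(
        block_size: int,
        n_kv_heads: int,
        head_dim: int,
        num_layers: int,
        dtype: torch.dtype = torch.float16,
        gpu_memory_utilization: float = 0.9,  # the commit hardcodes 0.9
    ) -> int:
        # Illustrative helper, not repo code: mirrors the arithmetic in the diff.
        # element_size() gives bytes per element of the chosen dtype, the same
        # trick the commit uses via torch.randn(1, dtype=...).
        element_size = torch.randn(1, dtype=dtype).element_size()
        # One cache block holds block_size tokens of both K and V (hence the
        # factor 2) across all KV heads of a single layer.
        block_memory_size = 2 * block_size * n_kv_heads * head_dim * element_size
        # torch.cuda.mem_get_info() returns (free_bytes, total_bytes); the
        # commit indexes element 1, so the budget is a fraction of total
        # device memory, split across every layer's cache.
        total_bytes = torch.cuda.mem_get_info()[1]
        return floor((total_bytes * gpu_memory_utilization) / (block_memory_size * num_layers))


    # Hypothetical usage with Llama-2-7B-like dimensions
    # (block_size 16, 32 KV heads, head_dim 128, 32 layers):
    # total_num_blocks = compute_total_num_blocks(16, 32, 128, 32)

One design nuance: because `torch.cuda.mem_get_info()` returns `(free_bytes, total_bytes)` and the commit takes index 1, the cache is budgeted against 90% of total rather than free memory. That is reasonable when the profiler has the GPU to itself; on a shared device, `mem_get_info()[0]` would be the more conservative bound.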
