adding GQA #1139

Open
wants to merge 1 commit into base: main
14 changes: 11 additions & 3 deletions llama/generation.py
@@ -168,12 +168,14 @@ def generate(
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")

if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id

if min_prompt_len == total_len:
logits = self.model.forward(tokens, prev_pos)
token_logprobs = -F.cross_entropy(
@@ -184,7 +186,7 @@
)

for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
logits = self.model.forward(tokens[:, :cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
@@ -197,13 +199,15 @@
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token

if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)

eos_reached |= (~input_text_mask[:, cur_pos]) & (
next_token == self.tokenizer.eos_id
)
@@ -213,23 +217,27 @@

if logprobs:
token_logprobs = token_logprobs.tolist()

out_tokens, out_logprobs = [], []
for i, toks in enumerate(tokens.tolist()):
# cut to max gen len
start = 0 if echo else len(prompt_tokens[i])
toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
probs = None

if logprobs:
probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
# cut to eos tok if any

if self.tokenizer.eos_id in toks:
eos_idx = toks.index(self.tokenizer.eos_id)
toks = toks[:eos_idx]
probs = probs[:eos_idx] if logprobs else None

out_tokens.append(toks)
out_logprobs.append(probs)

return (out_tokens, out_logprobs if logprobs else None)
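
A note on the forward-call change above: the original line feeds only the new slice tokens[:, prev_pos:cur_pos] and relies on the per-layer KV caches (written at absolute offsets starting at prev_pos), while the replacement re-feeds the whole prefix tokens[:, :cur_pos] every step. A minimal sketch of the two call patterns, assuming a model with the same forward(tokens, start_pos) signature (the helper name here is illustrative, not part of the patch):

import torch

def decode_step(model, tokens: torch.Tensor, prev_pos: int, cur_pos: int,
                incremental: bool = True) -> torch.Tensor:
    """One decoding step; the caller typically samples from logits[:, -1]."""
    if incremental:
        # Feed only the unseen tokens; earlier positions are served from the
        # KV cache that was filled at offsets starting at prev_pos.
        return model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
    # Re-feed the full prefix from position 0; with a caching model this only
    # stays consistent if start_pos is 0, otherwise cache offsets drift.
    return model.forward(tokens[:, :cur_pos], 0)
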


def text_completion(
self,
prompts: List[str],
118 changes: 77 additions & 41 deletions llama/model.py
@@ -8,12 +8,14 @@
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F

from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
ParallelEmbedding,
RowParallelLinear,
)
from torch import nn
import torch.nn as nn


@dataclass
@@ -24,13 +26,57 @@ class ModelArgs:
n_kv_heads: Optional[int] = None
vocab_size: int = -1 # defined later by tokenizer
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
ffn_dim_multiplier: float
norm_eps: float = 1e-5

max_batch_size: int = 32
max_seq_len: int = 2048



query_groups: int = 32 # New parameter for GQA
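
The patch does not pin down how query_groups relates to the existing n_heads / n_kv_heads fields; one common convention (an assumption here, not something stated in the diff) is that it counts how many query heads share a single key/value head:

# Illustrative bookkeeping under the conventional GQA interpretation.
n_heads = 32                              # query heads
n_kv_heads = 8                            # shared key/value heads
query_groups = n_heads // n_kv_heads      # 4 query heads per KV head
assert n_heads % n_kv_heads == 0
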




class GroupedQueryAttention(nn.Module):
def __init__(self, embed_dim, num_heads, query_groups):
super(GroupedQueryAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query_groups = query_groups
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
self.o_proj = nn.Linear(embed_dim, embed_dim)
self.scale = self.head_dim ** -0.5

def forward(self, x):
B, T, C = x.shape
qkv = self.qkv_proj(x)
qkv = qkv.view(B, T, self.num_heads, 3 * self.head_dim)
q, k, v = qkv.chunk(3, dim=-1)

q_groups = q.split(self.query_groups, dim=1)
k_groups = k.split(self.query_groups, dim=1)
v_groups = v.split(self.query_groups, dim=1)

attn_outputs = []
for q_group, k_group, v_group in zip(q_groups, k_groups, v_groups):
scores = torch.einsum('bthd,bThd->bhtT', q_group, k_group) * self.scale
attn_weights = torch.nn.functional.softmax(scores, dim=-1)
attn_output = torch.einsum('bhtT,bThd->bthd', attn_weights, v_group)
attn_outputs.append(attn_output)

attn_output = torch.cat(attn_outputs, dim=1).contiguous()
attn_output = attn_output.view(B, T, C)
output = self.o_proj(attn_output)
return output
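
For comparison, a minimal self-contained sketch of grouped-query attention in its usual formulation — fewer key/value heads than query heads, with K/V tiled across the head dimension — written independently of this patch (class and argument names are illustrative, and the causal mask is omitted for brevity):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GQASketch(nn.Module):
    """n_heads query heads share n_kv_heads key/value heads."""

    def __init__(self, embed_dim: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert embed_dim % n_heads == 0 and n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = embed_dim // n_heads
        self.q_proj = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * self.head_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Tile each KV head so every group of query heads sees its shared KV head.
        rep = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(rep, dim=1)          # (B, n_heads, T, head_dim)
        v = v.repeat_interleave(rep, dim=1)
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
        out = F.softmax(scores, dim=-1) @ v          # (B, n_heads, T, head_dim)
        return self.o_proj(out.transpose(1, 2).reshape(B, T, -1))

# e.g. GQASketch(embed_dim=4096, n_heads=32, n_kv_heads=8)(torch.randn(2, 16, 4096))
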




class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
@@ -173,14 +219,6 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
)


class Attention(nn.Module):
"""Multi-head attention module."""
def __init__(self, args: ModelArgs):
"""
Initialize the Attention module.

Args:
args (ModelArgs): Model configuration parameters.

Attributes:
n_kv_heads (int): Number of key and value heads.
@@ -195,14 +233,17 @@ def __init__(self, args: ModelArgs):
cache_k (torch.Tensor): Cached keys for attention.
cache_v (torch.Tensor): Cached values for attention.

"""
class Attention(nn.Module):
"""Multi-head attention module with Grouped Query Attention."""
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.query_groups = args.query_groups # Add query_groups parameter in ModelArgs

self.wq = ColumnParallelLinear(
args.dim,
@@ -257,19 +298,6 @@ def forward(
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
"""
Forward pass of the attention module.

Args:
x (torch.Tensor): Input tensor.
start_pos (int): Starting position for caching.
freqs_cis (torch.Tensor): Precomputed frequency tensor.
mask (torch.Tensor, optional): Attention mask tensor.

Returns:
torch.Tensor: Output tensor after attention.

"""
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

@@ -295,13 +323,26 @@ def forward(
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)

# Split queries, keys, values into groups for GQA
q_groups = xq.split(self.query_groups, dim=1)
k_groups = keys.split(self.query_groups, dim=1)
v_groups = values.split(self.query_groups, dim=1)

attn_outputs = []
for q_group, k_group, v_group in zip(q_groups, k_groups, v_groups):
scores = torch.matmul(q_group, k_group.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(q_group)
attn_output = torch.matmul(scores, v_group) # (bs, n_local_heads, seqlen, head_dim)
attn_outputs.append(attn_output)

# Concatenate attention outputs from all groups
attn_output = torch.cat(attn_outputs, dim=1).contiguous()
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(attn_output)
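
Note that the surrounding file already routes grouped-query attention through the repeat_kv helper whose signature appears in the hunk further up: K/V are projected with n_kv_heads heads and each head is tiled n_rep times before the score computation. A minimal sketch of that expansion pattern (the helper's actual body is outside the lines shown here), assuming x has shape (bs, seqlen, n_kv_heads, head_dim):

import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Tile each KV head n_rep times so keys/values line up with the query heads."""
    if n_rep == 1:
        return x
    bs, slen, n_kv_heads, head_dim = x.shape
    return (
        x[:, :, :, None, :]                               # add a repeat axis
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)    # broadcast view, no copy yet
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)  # (bs, slen, n_heads, head_dim)
    )
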



class FeedForward(nn.Module):
@@ -348,6 +389,7 @@ def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
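
For context on the multiple_of and ffn_dim_multiplier fields in ModelArgs above, the FeedForward constructor (elided from this diff) follows the usual Llama-style rounding of the SwiGLU hidden size; the sketch below reflects that upstream convention rather than anything taken verbatim from this patch:

from typing import Optional

def swiglu_hidden_dim(dim: int, multiple_of: int = 256,
                      ffn_dim_multiplier: Optional[float] = None) -> int:
    # Start from 4*dim, shrink to 2/3 (SwiGLU uses three weight matrices),
    # optionally rescale, then round up to a multiple of `multiple_of`.
    hidden = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden = int(ffn_dim_multiplier * hidden)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

# e.g. swiglu_hidden_dim(4096) == 11008
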



class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
"""
@@ -372,7 +414,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.attention = Attention(args) # Use the updated Attention class
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
@@ -410,6 +452,10 @@ def forward(
return out






class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
"""
@@ -427,7 +473,6 @@ def __init__(self, params: ModelArgs):
norm (RMSNorm): Layer normalization for the model output.
output (ColumnParallelLinear): Linear layer for final output.
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.

"""
super().__init__()
self.params = params
@@ -448,8 +493,6 @@ def __init__(self, params: ModelArgs):
)

self.freqs_cis = precompute_freqs_cis(
# Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
# Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
)

@@ -464,7 +507,6 @@ def forward(self, tokens: torch.Tensor, start_pos: int):

Returns:
torch.Tensor: Output logits after applying the Transformer model.

"""
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
@@ -476,13 +518,7 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
mask = torch.full(
(seqlen, seqlen), float("-inf"), device=tokens.device
)

mask = torch.triu(mask, diagonal=1)

# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack([
torch.zeros((seqlen, start_pos), device=tokens.device),
mask
43 changes: 35 additions & 8 deletions llama/tokenizer.py
@@ -11,16 +11,18 @@
logger = getLogger()




class Tokenizer:
"""tokenizing and encoding/decoding text using SentencePiece."""
"""Tokenizing and encoding/decoding text using SentencePiece."""

def __init__(self, model_path: str):
"""
Initializes the Tokenizer with a SentencePiece model.

Args:
model_path (str): The path to the SentencePiece model file.
"""
# reload tokenizer
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
logger.info(f"Reloaded SentencePiece model from {model_path}")
@@ -35,7 +37,7 @@ def __init__(self, model_path: str):
)
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
def encode(self, s: str, bos: bool = True, eos: bool = True) -> List[int]:
"""
Encodes a string into a list of token IDs.

@@ -47,13 +49,13 @@ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
Returns:
List[int]: A list of token IDs.
"""
assert type(s) is str
t = self.sp_model.encode(s)
assert isinstance(s, str)
tokens = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
tokens = [self.bos_id] + tokens
if eos:
t = t + [self.eos_id]
return t
tokens = tokens + [self.eos_id]
return tokens

def decode(self, t: List[int]) -> str:
"""
Expand All @@ -66,3 +68,28 @@ def decode(self, t: List[int]) -> str:
str: The decoded string.
"""
return self.sp_model.decode(t)

def tokenize(self, s: str) -> List[str]:
"""
Tokenizes a string into subword tokens.

Args:
s (str): The input string to be tokenized.

Returns:
List[str]: A list of subword tokens.
"""
return self.sp_model.encode_as_pieces(s)

def detokenize(self, tokens: List[str]) -> str:
"""
Detokenizes a list of subword tokens into a string.

Args:
tokens (List[str]): The list of subword tokens to be detokenized.

Returns:
str: The detokenized string.
"""
return self.sp_model.decode_pieces(tokens)
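
A short usage sketch covering encode/decode together with the new tokenize/detokenize helpers (the model path is a placeholder, and the bos/eos defaults are the ones introduced above):

# Illustrative only; "tokenizer.model" stands in for a real SentencePiece file.
tok = Tokenizer(model_path="tokenizer.model")

ids = tok.encode("Hello world", bos=True, eos=False)   # List[int], BOS prepended
text = tok.decode(ids)                                  # back to a string

pieces = tok.tokenize("Hello world")                    # List[str] of subword pieces
roundtrip = tok.detokenize(pieces)                      # typically reproduces the input text
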