main_submission.py
from fastapi import FastAPI
import logging
# Lit-GPT imports
import sys
import time
from pathlib import Path
import json
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
import lightning as L
import torch
torch.set_float32_matmul_precision("high")
from lit_gpt import GPT, Tokenizer, Config
from lit_gpt.utils import lazy_load, quantization
# Toy submission imports
from helper import toysubmission_generate
from api import (
    ProcessRequest,
    ProcessResponse,
    TokenizeRequest,
    TokenizeResponse,
    Token,
)
app = FastAPI()
logger = logging.getLogger(__name__)
# Configure the logging module
logging.basicConfig(level=logging.INFO)
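# Model setup runs once at import time, so the quantized weights are loaded
# before the server starts serving and are shared by all request handlers.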
quantize = "bnb.nf4-dq" # 4-bit NormalFloat with Double-Quantization (see QLoRA paper)
checkpoint_dir = Path("checkpoints/openlm-research/open_llama_3b")
precision = "bf16-true" # weights and data in bfloat16 precision
fabric = L.Fabric(devices=1, accelerator="cuda", precision=precision)
with open(checkpoint_dir / "lit_config.json") as fp:
    config = Config(**json.load(fp))
checkpoint_path = checkpoint_dir / "lit_model.pth"
logger.info(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}")
with fabric.init_module(empty_init=True), quantization(quantize):
    model = GPT(config)
with lazy_load(checkpoint_path) as checkpoint:
    model.load_state_dict(checkpoint, strict=quantize is None)
model.eval()
model = fabric.setup(model)
tokenizer = Tokenizer(checkpoint_dir)
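# /process generates a completion for the prompt and returns the decoded text
# together with per-token log-probabilities and the top token at each step.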
@app.post("/process")
async def process_request(input_data: ProcessRequest) -> ProcessResponse:
    if input_data.seed is not None:
        L.seed_everything(input_data.seed)
    logger.info("Using device: {}".format(fabric.device))
    encoded = tokenizer.encode(
        input_data.prompt, bos=True, eos=False, device=fabric.device
    )
    prompt_length = encoded.size(0)
    max_returned_tokens = prompt_length + input_data.max_new_tokens
    assert max_returned_tokens <= model.config.block_size, (
        max_returned_tokens,
        model.config.block_size,
    )  # maximum rope cache length
    t0 = time.perf_counter()
    tokens, logprobs, top_logprobs = toysubmission_generate(
        model,
        encoded,
        max_returned_tokens,
        max_seq_length=max_returned_tokens,
        temperature=input_data.temperature,
        top_k=input_data.top_k,
    )
    t = time.perf_counter() - t0
    model.reset_cache()
    output = tokenizer.decode(tokens)
    tokens_generated = tokens.size(0) - prompt_length
    logger.info(
        f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec"
    )
    logger.info(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
    generated_tokens = []
    # Use a distinct loop variable so `t` keeps holding the elapsed inference
    # time that is reported as request_time below.
    for tok, lp, tlp in zip(tokens, logprobs, top_logprobs):
        idx, val = tlp
        tok_str = tokenizer.processor.decode([idx])
        token_tlp = {tok_str: val}
        generated_tokens.append(
            Token(text=tokenizer.decode(tok), logprob=lp, top_logprob=token_tlp)
        )
    logprobs_sum = sum(logprobs)
    return ProcessResponse(
        text=output, tokens=generated_tokens, logprob=logprobs_sum, request_time=t
    )
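# /tokenize encodes the given text with the model's tokenizer and returns the token ids.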
@app.post("/tokenize")
async def tokenize(input_data: TokenizeRequest) -> TokenizeResponse:
    logger.info("Using device: {}".format(fabric.device))
    t0 = time.perf_counter()
    encoded = tokenizer.encode(
        input_data.text, bos=True, eos=False, device=fabric.device
    )
    t = time.perf_counter() - t0
    tokens = encoded.tolist()
    return TokenizeResponse(tokens=tokens, request_time=t)
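# Example client calls (a minimal sketch, not part of the submission itself).
# The JSON field names mirror the attributes read from ProcessRequest and
# TokenizeRequest above; the host, port, and server launch command are
# assumptions, not something this file defines.
#
# Launch the server first, e.g.:  uvicorn main_submission:app --port 8080
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8080/process",
#       json={
#           "prompt": "Hello, my name is",
#           "max_new_tokens": 32,
#           "temperature": 0.8,
#           "top_k": 200,
#           "seed": 1234,
#       },
#   )
#   print(resp.json()["text"])
#
#   resp = requests.post(
#       "http://localhost:8080/tokenize", json={"text": "Hello, my name is"}
#   )
#   print(resp.json()["tokens"])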