feat: vllm llama integration #129

Merged: 105 commits, Sep 18, 2024
Commits
c125132
first commit
jorgeantonio21 Jun 22, 2024
2161306
logic for prepare model inputs
jorgeantonio21 Jun 23, 2024
efa422a
add changes
jorgeantonio21 Jun 24, 2024
408c4e9
add atoma paged attention dependency and refactor code to integrate w…
jorgeantonio21 Jul 26, 2024
576a61a
work on sampling logic
jorgeantonio21 Jul 31, 2024
403252d
refactor selected token indices computation
jorgeantonio21 Jul 31, 2024
45b821f
minor changes
jorgeantonio21 Aug 1, 2024
ee6f278
add minor changes
jorgeantonio21 Aug 1, 2024
b7a4a47
compiler issues and simplify code
jorgeantonio21 Aug 2, 2024
9a55dbd
add changes
jorgeantonio21 Aug 2, 2024
d0a9ad4
refactors
jorgeantonio21 Aug 2, 2024
9efcc44
resolve compilation issues
jorgeantonio21 Aug 2, 2024
ece840c
minor mods
jorgeantonio21 Aug 5, 2024
c349810
resolve compilation issues
jorgeantonio21 Aug 12, 2024
5b982ae
resolve compilation issues
jorgeantonio21 Aug 12, 2024
b1ec4fa
resolve compilation issues
jorgeantonio21 Aug 12, 2024
1731ae8
resolve compilation issues
jorgeantonio21 Aug 12, 2024
a32cfbb
resolve compilation issues
jorgeantonio21 Aug 12, 2024
f753d8a
clippy
jorgeantonio21 Aug 12, 2024
3c28a8a
testing
jorgeantonio21 Aug 12, 2024
15b4e36
testing
jorgeantonio21 Aug 12, 2024
dc3a56f
testing
jorgeantonio21 Aug 12, 2024
d2227c5
remove sampling from testing
jorgeantonio21 Aug 12, 2024
4fdf80c
testing
jorgeantonio21 Aug 12, 2024
054c28b
testing
jorgeantonio21 Aug 12, 2024
1bace2b
remove sampling from testing
jorgeantonio21 Aug 12, 2024
4f4b46b
fmt
jorgeantonio21 Aug 12, 2024
9cc4673
add further changes
jorgeantonio21 Aug 12, 2024
2244266
refactor parts of the code, and add more logs
jorgeantonio21 Aug 12, 2024
1ba71b4
add changes
jorgeantonio21 Aug 13, 2024
4749979
add changes
jorgeantonio21 Aug 13, 2024
a5d7973
add changes
jorgeantonio21 Aug 13, 2024
64f0e5d
add changes
jorgeantonio21 Aug 13, 2024
39f9869
add changes
jorgeantonio21 Aug 13, 2024
8b88da8
add changes
jorgeantonio21 Aug 13, 2024
13be558
add changes
jorgeantonio21 Aug 13, 2024
9ddf868
resolve tests
jorgeantonio21 Aug 13, 2024
0fd1932
fmt
jorgeantonio21 Aug 13, 2024
63c4851
remove logging requests
jorgeantonio21 Aug 14, 2024
d82ca19
remove unnecessary panic!
jorgeantonio21 Aug 14, 2024
8fc880c
review the `LlmService` impl
jorgeantonio21 Aug 14, 2024
10b43cc
refactor tokenizer service spawn and further refactor the llm service
jorgeantonio21 Aug 16, 2024
b5a5055
address PR changes
jorgeantonio21 Aug 19, 2024
9247e37
resolve compilation issues
jorgeantonio21 Aug 19, 2024
dc89688
resolve compilation issues
jorgeantonio21 Aug 19, 2024
19e85d7
resolve compilation issues
jorgeantonio21 Aug 19, 2024
5317627
resolve compilation issues
jorgeantonio21 Aug 19, 2024
42e3ac2
testing
jorgeantonio21 Aug 19, 2024
dd1e305
add llama tests
jorgeantonio21 Aug 19, 2024
149982c
Merge branch 'vllm-model-executor' into llama-integration
jorgeantonio21 Aug 20, 2024
322a22e
address compilation issues
jorgeantonio21 Aug 20, 2024
427dd2e
resolve further compilation errors
jorgeantonio21 Aug 20, 2024
7b189ed
resolve further compilation errors
jorgeantonio21 Aug 20, 2024
fa89ed2
small changes
jorgeantonio21 Aug 21, 2024
167cdba
merge main and resolve conflicts
jorgeantonio21 Aug 21, 2024
9e60969
resolve compilation issues
jorgeantonio21 Aug 21, 2024
63bda32
resolve compilation issues and tests
jorgeantonio21 Aug 21, 2024
4d3cda2
improve logging
jorgeantonio21 Aug 21, 2024
baaf28d
improve logging
jorgeantonio21 Aug 21, 2024
39e716c
validation tokenizer communication issues
jorgeantonio21 Aug 21, 2024
24fcc1c
address llama test
jorgeantonio21 Aug 22, 2024
e788e77
changes
jorgeantonio21 Aug 22, 2024
6324795
add logits processor per sequence group for sampling
jorgeantonio21 Aug 23, 2024
4423d38
changes
jorgeantonio21 Aug 24, 2024
96e72e0
change dependency for atoma paged attention
jorgeantonio21 Aug 25, 2024
33cf1b3
add changes
jorgeantonio21 Aug 25, 2024
81ea4a2
refactor to atoma paged attention
jorgeantonio21 Aug 25, 2024
12ac17c
refactor to atoma paged attention
jorgeantonio21 Aug 25, 2024
2ed5374
add changes
jorgeantonio21 Aug 25, 2024
63f08e0
remove token output streamer for now
jorgeantonio21 Aug 25, 2024
9508659
resolve issue
jorgeantonio21 Aug 25, 2024
6fe7578
add multiple eos token ids
jorgeantonio21 Sep 17, 2024
3ba993c
import LlamaEosToks
jorgeantonio21 Sep 17, 2024
06db112
add changes
jorgeantonio21 Sep 17, 2024
6e3cf9b
further refactor
jorgeantonio21 Sep 17, 2024
24df292
further refactor
jorgeantonio21 Sep 17, 2024
4623324
further refactor
jorgeantonio21 Sep 17, 2024
4d4776d
add changes
jorgeantonio21 Sep 17, 2024
c929172
add import
jorgeantonio21 Sep 17, 2024
b54956e
squeeze logits
jorgeantonio21 Sep 17, 2024
cb1e4a4
clippy
jorgeantonio21 Sep 18, 2024
7ccec11
address clippy and compilation errors
jorgeantonio21 Sep 18, 2024
2ea5eaa
address clippy and compilation errors
jorgeantonio21 Sep 18, 2024
dbbb1da
address clippy and compilation errors
jorgeantonio21 Sep 18, 2024
ff7d91d
address clippy and compilation errors
jorgeantonio21 Sep 18, 2024
cfefe66
clippy
jorgeantonio21 Sep 18, 2024
3af9475
clippy
jorgeantonio21 Sep 18, 2024
2ffcde9
address tests clippy
jorgeantonio21 Sep 18, 2024
c844b05
address tests clippy
jorgeantonio21 Sep 18, 2024
b1d19f3
address tests clippy
jorgeantonio21 Sep 18, 2024
cd94f97
address tests clippy
jorgeantonio21 Sep 18, 2024
a2eaa3d
add changes
jorgeantonio21 Sep 18, 2024
1724cdf
add changes
jorgeantonio21 Sep 18, 2024
d365983
add changes
jorgeantonio21 Sep 18, 2024
69f5fb0
add changes
jorgeantonio21 Sep 18, 2024
c68614a
refactor check for max token length
jorgeantonio21 Sep 18, 2024
bae0a9c
remove unnecessary files
jorgeantonio21 Sep 18, 2024
f7aabca
merge main and resolve conflicts
jorgeantonio21 Sep 18, 2024
92b52f1
resolve bug with decoding full stream of tokens
jorgeantonio21 Sep 18, 2024
40ed214
resolve bug with decoding full stream of tokens
jorgeantonio21 Sep 18, 2024
d7c8a7d
resolve bug with decoding full stream of tokens
jorgeantonio21 Sep 18, 2024
d08df67
resolve bug with decoding full stream of tokens
jorgeantonio21 Sep 18, 2024
af3deb1
add changes
jorgeantonio21 Sep 18, 2024
1952a6e
add changes
jorgeantonio21 Sep 18, 2024
54ef3cf
add changes
jorgeantonio21 Sep 18, 2024
1 change: 1 addition & 0 deletions Cargo.toml
@@ -18,6 +18,7 @@ members = [
"atoma-storage",
"atoma-streamer",
"atoma-types",
"atoma-vllm",
]

[workspace.package]
5 changes: 4 additions & 1 deletion atoma-vllm/Cargo.toml
@@ -10,15 +10,18 @@ edition.workspace = true
async-trait.workspace = true
atoma-paged-attention.workspace = true
candle-core = { version = "0.6.0", features = ["cuda"] }
candle-nn = { version = "0.6.0", features = ["cuda"] }
candle-transformers = { version = "0.6.0", features = ["cuda"] }
futures.workspace = true
hf-hub.workspace = true
indexmap.workspace = true
metrics.workspace = true
metrics-exporter-prometheus.workspace = true
serde.workspace = true
serde_json.workspace = true
thiserror.workspace = true
tokenizers.workspace = true
tokio = { workspace = true, features = ["macros"] }
tokio = { workspace = true, features = ["macros", "fs"] }
tracing.workspace = true

[dev-dependencies]
4 changes: 4 additions & 0 deletions atoma-vllm/src/block_manager.rs
@@ -716,6 +716,8 @@ pub enum BlockSpaceManagerError {
pub(crate) mod tests {
use std::sync::{Arc, RwLock};

use candle_transformers::generation::LogitsProcessor;

use crate::sequence::{tests::create_dummy_prompt, LogProb};

use super::*;
@@ -828,6 +830,7 @@ pub(crate) mod tests {
Instant::now(),
Default::default(),
Default::default(),
LogitsProcessor::new(0, None, None),
)
.expect("Failed to construct a new `SequenceGroup`");

@@ -1145,6 +1148,7 @@ pub(crate) mod tests {
Instant::now(),
Default::default(),
Default::default(),
LogitsProcessor::new(0, None, None),
)
.expect("Failed to get `SequenceGroup`");
let parent = seq_group.sequences.values().next().unwrap().clone();
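
Reviewer note: the test hunks above now pass a `candle_transformers::generation::LogitsProcessor` into `SequenceGroup` construction. A minimal, self-contained sketch of how that sampler behaves on its own; the tensor values are illustrative and not taken from this PR:

```rust
use candle_core::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> Result<()> {
    // `new(seed, temperature, top_p)`: with `temperature = None` the processor
    // falls back to greedy (argmax) sampling, matching the
    // `LogitsProcessor::new(0, None, None)` used in the tests above.
    let mut sampler = LogitsProcessor::new(0, None, None);

    // Illustrative logits over a 4-token vocabulary; index 2 has the highest score.
    let logits = Tensor::new(&[0.1f32, 0.2, 3.5, 0.7], &Device::Cpu)?;

    // `sample` returns the chosen token id (here, 2 under argmax).
    let token_id = sampler.sample(&logits)?;
    assert_eq!(token_id, 2);
    Ok(())
}
```
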
6 changes: 5 additions & 1 deletion atoma-vllm/src/config.rs
@@ -40,7 +40,6 @@ impl CacheConfig {
block_size: usize,
gpu_memory_utilization: f32,
swap_space: usize,
cache_dtype: String,
num_gpu_blocks_override: Option<usize>,
sliding_window: Option<usize>,
num_cpu_blocks: usize,
@@ -224,18 +223,23 @@ pub enum SchedulerConfigError {
#[derive(Clone, Debug)]
pub struct ModelConfig {
/// HuggingFace model identifier
#[allow(dead_code)]
model_id: String,
/// Dtype
#[allow(dead_code)]
dtype: String,
/// The model revision identifier
#[allow(dead_code)]
revision: String,
/// Maximum length of a sequence (including prompt and
/// output). If None, will be derived from the model.
#[allow(dead_code)]
max_model_len: usize,
/// Whether to disable sliding window. If True,
/// we will disable the sliding window functionality of the model.
/// If the model does not support sliding window, this argument is
/// ignored.
#[allow(dead_code)]
disable_sliding_window: bool,
}

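
The `ModelConfig` fields above gain `#[allow(dead_code)]`, which silences the "field is never read" lint until those fields are actually consumed. A minimal, self-contained illustration of what the attribute does; the struct and values here are illustrative, not from this crate:

```rust
/// Illustrative struct: `revision` is stored but never read anywhere yet.
struct ModelHandle {
    model_id: String,
    // Without this attribute, rustc warns that `revision` is never read;
    // the attribute suppresses that one dead-code warning only.
    #[allow(dead_code)]
    revision: String,
}

fn main() {
    let handle = ModelHandle {
        model_id: "some-org/some-model".to_string(),
        revision: "main".to_string(),
    };
    // Only `model_id` is read, so `revision` would otherwise be flagged as dead code.
    println!("loading {}", handle.model_id);
}
```
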
1 change: 1 addition & 0 deletions atoma-vllm/src/lib.rs
@@ -10,6 +10,7 @@ pub mod evictor;
pub mod llm_engine;
pub mod llm_service;
pub mod model_executor;
pub mod models;
pub mod policy;
pub mod sampling_params;
pub mod scheduler;
41 changes: 34 additions & 7 deletions atoma-vllm/src/llm_engine.rs
@@ -11,7 +11,7 @@ use tokio::sync::{
mpsc::{error::SendError, UnboundedReceiver, UnboundedSender},
oneshot::error::RecvError,
};
use tracing::{error, info, instrument};
use tracing::{error, info, info_span, instrument, trace, Span};

use crate::{
model_executor::ModelThreadDispatcher,
@@ -53,6 +53,8 @@ pub struct LlmEngine {
scheduler: Scheduler<FcfsPolicy>,
/// Tokenizer for decoding sequences
tokenizer: Tokenizer,
/// Tracing span
span: Span,
}

impl LlmEngine {
@@ -72,6 +74,7 @@
scheduler,
tokenizer,
request_receiver,
span: info_span!("llm-engine"),
}
}

@@ -85,9 +88,13 @@
/// service.
#[instrument(skip(self))]
pub async fn run(mut self) -> Result<(), EngineError> {
let span = self.span.clone();
let _enter = span.enter();

loop {
tokio::select! {
Some(sequence_group) = self.request_receiver.recv() => {
trace!("Received new sequence group, with id = {}", sequence_group.request_id);
// 1. Adds the received `SequenceGroup` to the `Scheduler` instance.
self.scheduler.add_sequence_group(sequence_group);

@@ -116,6 +123,9 @@
&mut self,
outputs: Result<Vec<SequenceGroupOutput>, EngineError>,
) -> Result<(), EngineError> {
let span = self.span.clone();
let _enter = span.enter();

match outputs {
Ok(outputs) => {
// 1. Processes the newly AI generated outputs
@@ -152,6 +162,9 @@
/// 2. It sends a new `ExecuteModelRequest` to the `ModelExecutor`'s thread.
#[instrument(skip_all)]
pub fn step(&mut self) -> Result<(), EngineError> {
let span = self.span.clone();
let _enter = span.enter();

info!("`LlmEngine` new step..");
// 1. Schedule new requests
let (sequence_groups_metadata, scheduler_outputs) = self.scheduler.schedule()?;
@@ -293,13 +306,22 @@
.add_token_id(generated_token_id, sequence_output.logprob.clone())?;

// 5. Decode the generated output token id.
let generated_token = self
let token_ids = sequence_guard_lock.sequence_data.get_token_ids();
let generated_text = self
.tokenizer
.decode(&[generated_token_id], true)
.decode(&token_ids, true)
.map_err(|e| EngineError::TokenizerError(e.to_string()))?;

// 6. Update the `output_text` with the newly generated token,
// if in decoding phase.
let generated_token = if sequence_guard_lock.tokens.last().is_some() {
let start = sequence_guard_lock.output_text.chars().count();
generated_text.chars().skip(start).collect::<String>()
} else {
let start = sequence_guard_lock.prompt.chars().count();
generated_text.chars().skip(start).collect()
};

sequence_guard_lock.output_text.push_str(&generated_token);

// 7. Check if the last generated token is a stop token.
@@ -335,7 +357,7 @@ impl LlmEngine {
// 10. Check if the `Sequence`'s output length exceeds that of
// Request's `max_new_tokens`.
let sequence_output_len = sequence_guard_lock.get_output_len();
if sequence_output_len > stopping_criteria_params.max_new_tokens as usize {
if sequence_output_len >= stopping_criteria_params.max_new_tokens as usize {
sequence_guard_lock.set_sequence_status(SequenceStatus::FinishedLengthCapped)
}

@@ -344,7 +366,8 @@ impl LlmEngine {
} else {
// NOTE: in this case, we are not sampling newly
// generated tokens. That is, we are in prefill
// phase (possibly while chunking). For this reason,
// phase (possibly while chunking)
// without generating the next token. For this reason,
// we do not have to add tokens to the current
// `Sequence`'s state.

@@ -413,19 +436,23 @@ impl GenerateRequestOutput {
})
.collect::<Vec<_>>();

let is_finished = sequence_group.is_finished();
if is_finished {
sequence_group.set_finished_time(Instant::now());
}
Self {
request_id: sequence_group.request_id.clone(),
inference_outputs,
prompt: sequence_group.prompt(),
prompt_token_ids: sequence_group.prompt_token_ids(),
is_finished: sequence_group.is_finished(),
is_finished,
metrics: sequence_group.metrics.clone(),
}
}
}

/// `InferenceOutput` - Output of running AI inference on a given sequence group
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct InferenceOutput {
/// The index of the output in the request
pub index: usize,
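
Reviewer note on the tracing changes in `llm_engine.rs`: the engine now stores an `info_span!("llm-engine")` and re-enters it in `run`, `step`, and the output-handling method. A minimal, self-contained sketch of that span-per-component pattern; the struct and method names are illustrative, and the `tracing` and `tracing-subscriber` crates are assumed as dependencies:

```rust
use tracing::{info, info_span, Span};

/// Illustrative component that keeps one span for all of its log output.
struct Engine {
    span: Span,
}

impl Engine {
    fn new() -> Self {
        Self {
            // Events emitted while this span is entered are grouped under "llm-engine".
            span: info_span!("llm-engine"),
        }
    }

    fn step(&self) {
        // Entering the stored span scopes the events below to it; the guard
        // exits the span when it is dropped at the end of the method.
        let _enter = self.span.enter();
        info!("engine step");
    }
}

fn main() {
    // A subscriber must be installed for the events to be visible.
    tracing_subscriber::fmt::init();
    let engine = Engine::new();
    engine.step();
}
```
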
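Reviewer note on the "resolve bug with decoding full stream of tokens" commits: the hunk above replaces per-token decoding with decoding the full token stream and keeping only the suffix past the characters already emitted, which avoids broken output around multi-token characters and merged byte pieces. A minimal sketch of that incremental-decode idea against the `tokenizers` crate; the function, file path, and prompt are illustrative, error handling is simplified to `String`, and the `decode(&[u32], bool)` signature assumes a recent `tokenizers` release:

```rust
use tokenizers::Tokenizer;

/// Decode all generated token ids and return only the text that has not been
/// emitted yet, given how many characters were already produced. This mirrors
/// the "decode the full stream, skip the known prefix" approach in the hunk above.
fn decode_new_text(
    tokenizer: &Tokenizer,
    all_token_ids: &[u32],
    chars_already_emitted: usize,
) -> Result<String, String> {
    // Decoding the whole sequence each step lets the tokenizer resolve byte-level
    // and multi-token characters correctly, unlike decoding one id at a time.
    let full_text = tokenizer
        .decode(all_token_ids, true)
        .map_err(|e| e.to_string())?;
    // Keep only the characters beyond what was already sent to the caller.
    Ok(full_text.chars().skip(chars_already_emitted).collect())
}

fn main() -> Result<(), String> {
    // Assumes a local `tokenizer.json` exported from any Hugging Face tokenizer.
    let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(|e| e.to_string())?;
    let ids = tokenizer
        .encode("hello world", false)
        .map_err(|e| e.to_string())?
        .get_ids()
        .to_vec();
    // Pretend the first 5 characters ("hello") were already streamed out;
    // roughly " world" should be the newly emitted suffix.
    let new_text = decode_new_text(&tokenizer, &ids, 5)?;
    println!("{new_text:?}");
    Ok(())
}
```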