
Commit

test
Cifko committed Mar 28, 2024
1 parent 41a3da9 commit c5f3964
Showing 4 changed files with 163 additions and 7 deletions.
149 changes: 149 additions & 0 deletions atoma-inference/src/candle/llama.rs
@@ -0,0 +1,149 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{bail, Error as E, Result};

use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};

use candle_transformers::models::llama as model;
use model::{Llama, LlamaConfig};
use tokenizers::Tokenizer;

use crate::candle::{device, hub_load_safetensors, token_output_stream::TokenOutputStream};

const EOS_TOKEN: &str = "</s>";

#[derive(Clone, Debug, Copy, PartialEq, Eq)]
enum Which {
    V1,
    V2,
    Solar10_7B,
    TinyLlama1_1BChat,
}

pub struct Config {
    temperature: Option<f64>,
    top_p: Option<f64>,
    seed: u64,
    sample_len: usize,
    no_kv_cache: bool,
    dtype: Option<String>,
    model_id: Option<String>,
    revision: Option<String>,
    which: Which,
    use_flash_attn: bool,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            temperature: None,
            top_p: None,
            seed: 299792458,
            sample_len: 10000,
            no_kv_cache: false,
            dtype: None,
            model_id: None,
            revision: None,
            which: Which::TinyLlama1_1BChat,
            use_flash_attn: false,
            repeat_penalty: 1.,
            repeat_last_n: 64,
        }
    }
}
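// `Config` fields are private, so external callers currently go through
// `Config::default()`. A hypothetical override from inside this module could look like:
// `Config { temperature: Some(0.8), which: Which::V2, ..Default::default() }`.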

pub fn run(prompt: String, cfg: Config) -> Result<String> {
    let device = device()?;
    let dtype = match cfg.dtype.as_deref() {
        Some("f16") => DType::F16,
        Some("bf16") => DType::BF16,
        Some("f32") => DType::F32,
        Some(dtype) => bail!("Unsupported dtype {dtype}"),
        None => DType::F16,
    };
    let (llama, tokenizer_filename, mut cache) = {
        let api = Api::new()?;
        let model_id = cfg.model_id.unwrap_or_else(|| match cfg.which {
            Which::V1 => "Narsil/amall-7b".to_string(),
            Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
            Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
            Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
        });
        println!("loading the model weights from {model_id}");
        let revision = cfg.revision.unwrap_or("main".to_string());
        let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

        let tokenizer_filename = api.get("tokenizer.json")?;
        let config_filename = api.get("config.json")?;
        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let config = config.into_config(cfg.use_flash_attn);

        let filenames = match cfg.which {
            Which::V1 | Which::V2 | Which::Solar10_7B => {
                hub_load_safetensors(&api, "model.safetensors.index.json")?
            }
            Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
        };
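        // `filenames` now lists every weight file to memory-map below: sharded checkpoints
        // are resolved via model.safetensors.index.json, while TinyLlama ships as a single
        // model.safetensors file.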
        let cache = model::Cache::new(!cfg.no_kv_cache, dtype, &config, &device)?;

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        (Llama::load(vb, &config)?, tokenizer_filename, cache)
    };
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
    let mut tokens = tokenizer
        .encode(prompt.clone(), true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();

    let mut tokenizer = TokenOutputStream::new(tokenizer);
    let mut logits_processor = LogitsProcessor::new(cfg.seed, cfg.temperature, cfg.top_p);
    let mut index_pos = 0;
    let mut res = String::new();
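    // Autoregressive sampling loop. With the KV cache enabled, only the most recent token
    // is fed after the first step and `index_pos` tracks the absolute position; without it,
    // the full token history is re-run on every step.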
    for index in 0..cfg.sample_len {
        let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
            (1, index_pos)
        } else {
            (tokens.len(), 0)
        };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
        let logits = llama.forward(&input, context_index, &mut cache)?;
        let logits = logits.squeeze(0)?;
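        // Optionally penalize tokens that appeared in the last `repeat_last_n` tokens;
        // a `repeat_penalty` of 1.0 leaves the logits untouched.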
        let logits = if cfg.repeat_penalty == 1. {
            logits
        } else {
            let start_at = tokens.len().saturating_sub(cfg.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                cfg.repeat_penalty,
                &tokens[start_at..],
            )?
        };
        index_pos += ctxt.len();

        let next_token = logits_processor.sample(&logits)?;
        tokens.push(next_token);

        if Some(next_token) == eos_token_id {
            break;
        }
        if let Some(t) = tokenizer.next_token(next_token)? {
            res += &t;
        }
    }
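    // Flush any characters still buffered by the streaming token decoder.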
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        res += &rest;
    }
    Ok(res)
}
1 change: 1 addition & 0 deletions atoma-inference/src/candle/mod.rs
@@ -1,3 +1,4 @@
pub mod llama;
pub mod stable_diffusion;
pub mod token_output_stream;

2 changes: 1 addition & 1 deletion atoma-inference/src/candle/stable_diffusion.rs
@@ -85,7 +85,7 @@ impl Input {
            guidance_scale: None,
            img2img: None,
            img2img_strength: 0.8,
            seed: None,
            seed: Some(0),
        }
    }
}
18 changes: 12 additions & 6 deletions atoma-inference/src/main.rs
@@ -20,16 +20,22 @@ fn stable_diffusion() {
    }
}

fn llama1() {
    use crate::candle::llama::run;
    let x = run(
        "The most important thing is ".to_string(),
        Default::default(),
    )
    .unwrap();
    println!("{}", x);
}

fn main() {
    stable_diffusion();
    // stable_diffusion();
    llama1();
    // let result = llama::run("One day I will").unwrap();
    // println!("{}", result);

    // let x = run(
    //     "The most important thing is ".to_string(),
    //     Default::default(),
    // )
    // .unwrap();
    // println!("{}", x);
    // run(
    //     "Green boat on ocean during storm".to_string(),

0 comments on commit c5f3964
