
add stable diffusion
Cifko committed Mar 27, 2024
1 parent 48e7845 · commit 41a3da9
Showing 10 changed files with 830 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,2 +1,3 @@
Cargo.lock
target/
.vscode/
16 changes: 15 additions & 1 deletion Cargo.toml
@@ -1,17 +1,31 @@
[workspace]
resolver = "2"
edition = "2021"

members = ["atoma-event-subscribe", "atoma-inference", "atoma-networking", "atoma-json-rpc", "atoma-storage"]
members = [
    "atoma-event-subscribe",
    "atoma-inference",
    "atoma-networking",
    "atoma-json-rpc",
    "atoma-storage",
]

[workspace.package]
version = "0.1.0"

[workspace.dependencies]
anyhow = "1.0.81"
async-trait = "0.1.78"
candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" }
candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" }
candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" }
clap = "4.5.3"
ed25519-consensus = "2.1.0"
hf-hub = "0.3.0"
image = { version = "0.25.0", default-features = false, features = [
    "jpeg",
    "png",
] }
serde = "1.0.197"
thiserror = "1.0.58"
tokenizers = "0.15.2"
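Member crates pick these versions up through workspace dependency inheritance, e.g. `candle.workspace = true` in atoma-inference/Cargo.toml below.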
18 changes: 13 additions & 5 deletions atoma-inference/Cargo.toml
@@ -1,18 +1,26 @@
[package]
name = "inference"
version = "0.1.0"
version.workspace = true
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow.workspace = true
async-trait.workspace = true
candle.workspace = true
candle-nn.workspace = true
candle-transformers.workspace = true
candle.workspace = true
clap.workspace = true
ed25519-consensus.workspace = true
hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = "1.0.114"
thiserror.workspace = true
tokenizers.workspace = true
tokenizers = { workspace = true, features = ["onig"] }
tokio = { workspace = true, features = ["full", "tracing"] }
tracing.workspace = true
llama_cpp = "0.3.1"

[features]
cuda = ["candle/cuda", "candle-nn/cuda"]
metal = ["candle/metal", "candle-nn/metal"]
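These features would typically be enabled at build time, e.g. `cargo build -p inference --features cuda` (assuming the package name `inference` from the manifest above and a CUDA-capable toolchain; `--features metal` on Apple hardware).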
75 changes: 75 additions & 0 deletions atoma-inference/src/candle/mod.rs
@@ -0,0 +1,75 @@
pub mod stable_diffusion;
pub mod token_output_stream;

use std::path::PathBuf;

use candle::{
    utils::{cuda_is_available, metal_is_available},
    Device, Tensor,
};
use tracing::info;

use crate::models::ModelError;

/// A candle-backed model that can fetch its weights and run inference.
pub trait CandleModel {
    type Fetch;
    type Input;
    fn fetch(fetch: &Self::Fetch) -> Result<(), ModelError>;
    fn inference(input: Self::Input) -> Result<Vec<Tensor>, ModelError>;
}

/// Picks the inference device, preferring CUDA, then Metal, and falling back to the CPU.
pub fn device() -> Result<Device, candle::Error> {
    if cuda_is_available() {
        info!("Using CUDA");
        Device::new_cuda(0)
    } else if metal_is_available() {
        info!("Using Metal");
        Device::new_metal(0)
    } else {
        info!("Using Cpu");
        Ok(Device::Cpu)
    }
}

/// Downloads every safetensors file referenced by the `weight_map` of the given index JSON
/// in the hub repo and returns their local paths.
pub fn hub_load_safetensors(
    repo: &hf_hub::api::sync::ApiRepo,
    json_file: &str,
) -> candle::Result<Vec<std::path::PathBuf>> {
    let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
    let json_file = std::fs::File::open(json_file)?;
    let json: serde_json::Value =
        serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
    let weight_map = match json.get("weight_map") {
        None => candle::bail!("no weight map in {json_file:?}"),
        Some(serde_json::Value::Object(map)) => map,
        Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
    };
    let mut safetensors_files = std::collections::HashSet::new();
    for value in weight_map.values() {
        if let Some(file) = value.as_str() {
            safetensors_files.insert(file.to_string());
        }
    }
    let safetensors_files = safetensors_files
        .iter()
        .map(|v| repo.get(v).map_err(candle::Error::wrap))
        .collect::<candle::Result<Vec<_>>>()?;
    Ok(safetensors_files)
}

/// Writes a `(3, height, width)` `u8` tensor to `p` as an image file.
pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> candle::Result<()> {
    let p = p.as_ref();
    let (channel, height, width) = img.dims3()?;
    if channel != 3 {
        candle::bail!("save_image expects an input of shape (3, height, width)")
    }
    let img = img.permute((1, 2, 0))?.flatten_all()?;
    let pixels = img.to_vec1::<u8>()?;
    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
            Some(image) => image,
            None => candle::bail!("error saving image {p:?}"),
        };
    image.save(p).map_err(candle::Error::wrap)?;
    Ok(())
}
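For context, a minimal sketch of how these helpers could be wired together; the repository id, index file name, and `image_tensor` below are illustrative assumptions, not part of this commit:

    // Choose CUDA, Metal, or CPU depending on what is available; the device
    // would be passed to the stable diffusion pipeline when it is built.
    let device = device()?;

    // Resolve the weight files referenced by a (hypothetical) safetensors index.
    let api = hf_hub::api::sync::Api::new().map_err(candle::Error::wrap)?;
    let repo = api.model("stabilityai/stable-diffusion-2-1".to_string());
    let weights = hub_load_safetensors(&repo, "unet/diffusion_pytorch_model.safetensors.index.json")?;

    // Once inference yields a (3, height, width) u8 tensor, persist it to disk.
    save_image(&image_tensor, "out.png")?;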
