Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pass task-specific config to backend #922

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions flutter/assets/tasks.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,12 @@ task {
id: "stable_diffusion"
name: "StableDiffusion"
}
custom_config {
id: "stable_diffusion_seed"
value: "633994880"
}
custom_config {
id: "stable_diffusion_num_steps"
value: "20"
}
Comment on lines +250 to +257
Copy link
Collaborator Author

@anhappdev anhappdev Sep 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This how a custom setting for all backends can be defined in tasks.pbtxt

}
10 changes: 5 additions & 5 deletions flutter/cpp/binary/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ int Main(int argc, char *argv[]) {
command_line += " " + backend_name + " " + benchmark_id;

// Command Line Flags for mlperf.
std::string mode, scenario = "SingleStream", output_dir;
std::string mode, scenario = "SingleStream", output_dir, custom_config;
int min_query_count = 100, min_duration_ms = 100,
max_duration_ms = 10 * 60 * 1000,
single_stream_expected_latency_ns = 1000000;
Expand All @@ -157,8 +157,9 @@ int Main(int argc, char *argv[]) {
"A hint used by the loadgen to pre-generate "
"enough samples to meet the minimum test duration."),
Flag::CreateFlag("output_dir", &output_dir,
"The output directory of mlperf.", Flag::kRequired)});

"The output directory of mlperf.", Flag::kRequired),
Flag::CreateFlag("custom_config", &custom_config,
"Custom config in form key1:val1,key2:val2.")});
// Command Line Flags for backend.
std::unique_ptr<Backend> backend;
std::unique_ptr<Dataset> dataset;
Expand Down Expand Up @@ -207,9 +208,8 @@ int Main(int argc, char *argv[]) {
}
}
}

SettingList setting_list =
createSettingList(backend_setting, benchmark_id);
CreateSettingList(backend_setting, custom_config, benchmark_id);

ExternalBackend *external_backend = new ExternalBackend(
model_file_path, lib_path, setting_list, native_lib_path);
Expand Down
14 changes: 13 additions & 1 deletion flutter/cpp/proto/mlperf_task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ message MLPerfConfig {
// Config of the mlperf tasks.
// A task is basically a combination of models and a dataset.
//
// Next ID: 11
// Next ID: 12
message TaskConfig {
// Must be unique in one task file. Ex: image_classification
// used to match backend settings
Expand All @@ -52,6 +52,7 @@ message TaskConfig {
required string scenario = 7;
required DatasetConfig datasets = 8;
required ModelConfig model = 9;
repeated CustomConfig custom_config = 11;
}

// Datasets for a task
Expand Down Expand Up @@ -107,3 +108,14 @@ message ModelConfig {
// Number of detection classes if applicable
optional int32 num_classes = 6;
}

// CustomConfig are task specific configuration.
// The TaskConfig.CustomConfig will be converted to
// BenchmarkSetting.CustomSetting and passed to the backend.
// To avoid name collision, the id should be prefixed with TaskConfig.id.
message CustomConfig {
// Id of this config.
required string id = 1;
// Value of this config.
required string value = 2;
}
4 changes: 3 additions & 1 deletion flutter/cpp/proto/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,11 @@ int test_proto() {
std::list<std::string> benchmarks;
benchmarks.push_back("image_classification");
benchmarks.push_back("image_classification_offline");
std::string custom_config = "key1:val1,key2:val2";
for (auto benchmark_id : benchmarks) {
// Convert to SettingList
SettingList setting_list = createSettingList(backend_setting, benchmark_id);
SettingList setting_list =
CreateSettingList(backend_setting, custom_config, benchmark_id);

std::cout << "SettingList for " << benchmark_id << ":\n";
dumpSettingList(setting_list);
Expand Down
51 changes: 45 additions & 6 deletions flutter/cpp/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,61 @@ mlperf_backend_configuration_t CppToCSettings(const SettingList &settings) {
return c_settings;
}

SettingList createSettingList(const BackendSetting &backend_setting,
std::string benchmark_id) {
// Split the string by a given delimiter
std::vector<std::string> _splitString(const std::string &str, char delimiter) {
std::vector<std::string> tokens;
std::stringstream ss(str);
std::string token;
while (std::getline(ss, token, delimiter)) {
tokens.push_back(token);
}
return tokens;
}

// Parse the key:value string list
std::unordered_map<std::string, std::string> _parseKeyValueList(
const std::string &input) {
std::unordered_map<std::string, std::string> keyValueMap;
std::vector<std::string> pairs = _splitString(input, ','); // Split by comma

for (const std::string &pair : pairs) {
std::vector<std::string> keyValue =
_splitString(pair, ':'); // Split by colon
if (keyValue.size() == 2) {
keyValueMap[keyValue[0]] = keyValue[1];
} else {
LOG(ERROR) << "Invalid key:value pair: " << pair;
}
}
return keyValueMap;
}

// Create the setting list for backend
SettingList CreateSettingList(const BackendSetting &backend_setting,
const std::string &custom_config,
const std::string &benchmark_id) {
SettingList setting_list;
int setting_index = 0;

for (auto setting : backend_setting.common_setting()) {
for (const auto &setting : backend_setting.common_setting()) {
setting_list.add_setting();
(*setting_list.mutable_setting(setting_index)) = setting;
setting_index++;
}

// Copy the benchmark specific settings
setting_index = 0;
for (auto bm_setting : backend_setting.benchmark_setting()) {
for (const auto &bm_setting : backend_setting.benchmark_setting()) {
if (bm_setting.benchmark_id() == benchmark_id) {
setting_list.mutable_benchmark_setting()->CopyFrom(bm_setting);

auto parsed = _parseKeyValueList(custom_config);
for (const auto &kv : parsed) {
CustomSetting custom_setting = CustomSetting();
custom_setting.set_id(kv.first);
custom_setting.set_value(kv.second);
setting_list.mutable_benchmark_setting()->mutable_custom_setting()->Add(
std::move(custom_setting));
}
break;
}
}
LOG(INFO) << "setting_list:" << std::endl << setting_list.DebugString();
Expand Down
5 changes: 3 additions & 2 deletions flutter/cpp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ void DeleteBackendConfiguration(mlperf_backend_configuration_t *configs);

mlperf_backend_configuration_t CppToCSettings(const SettingList &settings);

SettingList createSettingList(const BackendSetting &backend_setting,
std::string benchmark_id);
SettingList CreateSettingList(const BackendSetting &backend_setting,
const std::string &custom_config,
const std::string &benchmark_id);

} // namespace mobile
} // namespace mlperf
Expand Down
5 changes: 5 additions & 0 deletions flutter/lib/benchmark/benchmark.dart
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ class Benchmark {
setting: commonSettings,
benchmarkSetting: benchmarkSettings,
);
// Convert TaskConfig.CustomConfig to BenchmarkSetting.CustomSetting
final customConfigs = taskConfig.customConfig
.map((e) => pb.CustomSetting(id: e.id, value: e.value))
.toList();
benchmarkSettings.customSetting.addAll(customConfigs);
final uris = selectedDelegate.modelFile.map((e) => e.modelPath).toList();
final modelDirName = selectedDelegate.delegateName.replaceAll(' ', '_');
final backendModelPath =
Expand Down
3 changes: 2 additions & 1 deletion mobile_back_apple/dev-utils/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ tflite-run-sd:
bazel-bin/flutter/cpp/binary/main EXTERNAL stable_diffusion \
--mode=PerformanceOnly \
--output_dir="${REPO_ROOT_DIR}/output" \
--model_file="${REPO_ROOT_DIR}/mobile_back_apple/dev-resources/stable_diffusion/sd-models" \
--model_file="${REPO_ROOT_DIR}/mobile_back_apple/dev-resources/stable_diffusion/dynamic-sd-models" \
--lib_path="bazel-bin/mobile_back_tflite/cpp/backend_tflite/libtflitebackend.so" \
--custom_config="stable_diffusion_num_steps:20,stable_diffusion_seed:633994880" \
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the main.cc file for development, one can pass it with the flag custom_config

--input_tfrecord="${REPO_ROOT_DIR}/mobile_back_apple/dev-resources/stable_diffusion/coco_gen_full.tfrecord" \
--input_clip_model="${REPO_ROOT_DIR}/mobile_back_apple/dev-resources/stable_diffusion/clip_model_512x512.tflite" \
--min_query_count=5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ StableDiffusionInvoker::StableDiffusionInvoker(SDBackendData* backend_data)
: backend_data_(backend_data) {}

std::vector<float> StableDiffusionInvoker::invoke() {
std::cout << "Prompt encoding started" << std::endl;
LOG(INFO) << "Prompt encoding started";
auto encoded_text = encode_prompt(backend_data_->input_prompt_tokens);
auto unconditional_encoded_text =
encode_prompt(backend_data_->unconditional_tokens);
std::cout << "Diffusion process started" << std::endl;
LOG(INFO) << "Diffusion process started";
auto latent =
diffusion_process(encoded_text, unconditional_encoded_text,
backend_data_->num_steps, backend_data_->seed);
std::cout << "Image decoding started" << std::endl;
LOG(INFO) << "Image decoding started";
return decode_image(latent);
}

Expand Down Expand Up @@ -108,7 +108,7 @@ std::vector<float> StableDiffusionInvoker::diffusion_process(
auto alphas_prev = std::get<1>(alphas_tuple);

for (int i = timesteps.size() - 1; i >= 0; --i) {
std::cout << "Step " << timesteps.size() - 1 - i << "\n";
LOG(INFO) << "Step " << timesteps.size() - 1 - i;

auto latent_prev = latent;
auto t_emb = get_timestep_embedding(timesteps[i]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ mlperf_backend_ptr_t StableDiffusionPipeline::backend_create(
SDBackendData* backend_data = new SDBackendData();
backendExists = true;

for (int i = 0; i < configs->count; ++i) {
if (strcmp(configs->keys[i], "stable_diffusion_seed") == 0) {
backend_data->seed = atoi(configs->values[i]);
}
if (strcmp(configs->keys[i], "stable_diffusion_num_steps") == 0) {
backend_data->num_steps = atoi(configs->values[i]);
}
}

Comment on lines +67 to +75
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The backend can read them from mlperf_backend_configuration_t in mlperf_backend_create.

// Load models from the provided directory path
std::string text_encoder_path =
std::string(model_path) + "/sd_text_encoder_dynamic.tflite";
Expand Down
Loading