-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodelhub_backend_adapters.cpp
More file actions
112 lines (96 loc) · 3.87 KB
/
Copy pathmodelhub_backend_adapters.cpp
File metadata and controls
112 lines (96 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include "config/model_config.hpp"
#include "models/models_registry.hpp"
#include <stdexcept>
#include <string>
namespace {
struct BackendAdapter {
const char *model_type;
const char *delegate_backend;
const char *family;
};
const BackendAdapter kAdapters[] = {
{"gemma2", "llama", "dense-transformer"},
{"phi3", "llama", "dense-transformer"},
{"yi", "llama", "dense-transformer"},
{"deepseek_v2_lite", "qwen3_moe", "moe-transformer"},
{"olmoe", "qwen3_moe", "moe-transformer"},
};
void add_common_defaults(nlohmann::json &config) {
if (!config.contains("head_dim")) {
config["head_dim"] = config.at("hidden_size").get<size_t>()
/ config.at("num_attention_heads").get<size_t>();
}
if (!config.contains("num_key_value_heads")) {
config["num_key_value_heads"] = config.at("num_attention_heads");
}
if (!config.contains("attention_bias")) {
config["attention_bias"] = false;
}
if (!config.contains("tie_word_embeddings")) {
config["tie_word_embeddings"] = false;
}
}
std::shared_ptr<infinilm::config::ModelConfig> create_adapter_config(
std::shared_ptr<infinilm::config::ModelConfig> model_config,
const BackendAdapter &adapter) {
nlohmann::json &config = model_config->get_config_json();
const auto original_model_type = config.at("model_type").get<std::string>();
if (original_model_type != adapter.model_type) {
throw std::runtime_error(
"InfiniLM-ModelHub backend adapter expected model_type="
+ std::string(adapter.model_type) + ", got " + original_model_type);
}
add_common_defaults(config);
config["_infinilm_out_of_tree_backend_plugin"] = "infinilm_model_hub";
config["_infinilm_backend_delegate"] = adapter.delegate_backend;
config["_infinilm_backend_family"] = adapter.family;
auto config_it = infinilm::models::get_model_config_map().find(adapter.delegate_backend);
if (config_it != infinilm::models::get_model_config_map().end()) {
config["model_type"] = adapter.delegate_backend;
config_it->second(model_config);
config["model_type"] = original_model_type;
}
return model_config;
}
std::shared_ptr<infinilm::InfinilmModel> create_adapter_model(
std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device,
const BackendAdapter &adapter) {
nlohmann::json &config = model_config->get_config_json();
const auto original_model_type = config.at("model_type").get<std::string>();
const auto &model_map = infinilm::models::get_causal_lm_model_map();
auto model_it = model_map.find(adapter.delegate_backend);
if (model_it == model_map.end()) {
throw std::runtime_error(
"InfiniLM-ModelHub backend adapter could not find delegate backend: "
+ std::string(adapter.delegate_backend));
}
config["model_type"] = adapter.delegate_backend;
auto model = model_it->second(model_config, device);
config["model_type"] = original_model_type;
return model;
}
void register_adapter(const BackendAdapter &adapter) {
infinilm::models::register_model_config(
adapter.model_type,
[&adapter](std::shared_ptr<infinilm::config::ModelConfig> model_config) {
return create_adapter_config(std::move(model_config), adapter);
});
infinilm::models::register_causal_lm_model(
adapter.model_type,
[&adapter](std::shared_ptr<infinilm::config::ModelConfig> model_config,
const infinicore::Device &device) {
return create_adapter_model(std::move(model_config), device, adapter);
});
}
} // namespace
extern "C" void infinilm_backend_plugin_init() {
static bool registered = false;
if (registered) {
return;
}
for (const auto &adapter : kAdapters) {
register_adapter(adapter);
}
registered = true;
}