Skip to content

Commit 60db19a

Browse files
committed
Fix loading diffusers model
1 parent 10c6501 commit 60db19a

1 file changed

Lines changed: 66 additions & 6 deletions

File tree

model.cpp

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
275275
{"to_v", "v"},
276276
{"to_out_0", "proj_out"},
277277
{"group_norm", "norm"},
278+
{"key", "k"},
279+
{"query", "q"},
280+
{"value", "v"},
281+
{"proj_attn", "proj_out"},
278282
},
279283
},
280284
{
@@ -299,6 +303,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
299303
{"to_v", "v"},
300304
{"to_out.0", "proj_out"},
301305
{"group_norm", "norm"},
306+
{"key", "k"},
307+
{"query", "q"},
308+
{"value", "v"},
309+
{"proj_attn", "proj_out"},
302310
},
303311
},
304312
{
@@ -370,6 +378,10 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
370378
return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
371379
}
372380

381+
if (match(m, std::regex(format("unet%cadd_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) {
382+
return format("model%cdiffusion_model%clabel_emb%c0%c", seq, seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
383+
}
384+
373385
if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
374386
std::string suffix = get_converted_suffix(m[1], m[3]);
375387
// LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
@@ -407,6 +419,19 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
407419
return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0];
408420
}
409421

422+
// clip-g
423+
if (match(m, std::regex(format("te%c1%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
424+
return format("cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq, seq) + m[0] + seq + m[1];
425+
}
426+
427+
if (match(m, std::regex(format("te%c1%ctext_model(.*)", seq, seq)), key)) {
428+
return format("cond_stage_model%c1%ctransformer%ctext_model", seq, seq, seq) + m[0];
429+
}
430+
431+
if (match(m, std::regex(format("te%c1%ctext_projection", seq, seq)), key)) {
432+
return format("cond_stage_model%c1%ctransformer%ctext_model%ctext_projection", seq, seq, seq, seq);
433+
}
434+
410435
// vae
411436
if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) {
412437
return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str());
@@ -543,6 +568,8 @@ std::string convert_tensor_name(std::string name) {
543568
std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
544569
if (new_key.empty()) {
545570
new_name = name;
571+
} else if (new_key == "cond_stage_model.1.transformer.text_model.text_projection") {
572+
new_name = new_key;
546573
} else {
547574
new_name = new_key + "." + network_part;
548575
}
@@ -966,10 +993,14 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
966993
ttype = GGML_TYPE_F32;
967994
} else if (dtype == "F32") {
968995
ttype = GGML_TYPE_F32;
996+
} else if (dtype == "F64") {
997+
ttype = GGML_TYPE_F64;
969998
} else if (dtype == "F8_E4M3") {
970999
ttype = GGML_TYPE_F16;
9711000
} else if (dtype == "F8_E5M2") {
9721001
ttype = GGML_TYPE_F16;
1002+
} else if (dtype == "I64") {
1003+
ttype = GGML_TYPE_I64;
9731004
}
9741005
return ttype;
9751006
}
@@ -982,6 +1013,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
9821013
std::ifstream file(file_path, std::ios::binary);
9831014
if (!file.is_open()) {
9841015
LOG_ERROR("failed to open '%s'", file_path.c_str());
1016+
file_paths_.pop_back();
9851017
return false;
9861018
}
9871019

@@ -993,6 +1025,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
9931025
// read header size
9941026
if (file_size_ <= ST_HEADER_SIZE_LEN) {
9951027
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
1028+
file_paths_.pop_back();
9961029
return false;
9971030
}
9981031

@@ -1006,6 +1039,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10061039
size_t header_size_ = read_u64(header_size_buf);
10071040
if (header_size_ >= file_size_) {
10081041
LOG_ERROR("invalid safetensor file '%s'", file_path.c_str());
1042+
file_paths_.pop_back();
10091043
return false;
10101044
}
10111045

@@ -1016,6 +1050,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10161050
file.read(header_buf.data(), header_size_);
10171051
if (!file) {
10181052
LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str());
1053+
file_paths_.pop_back();
10191054
return false;
10201055
}
10211056

@@ -1071,7 +1106,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10711106
n_dims = 1;
10721107
}
10731108

1074-
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
1109+
std::string new_name = prefix + name;
1110+
new_name = convert_tensor_name(new_name);
1111+
1112+
TensorStorage tensor_storage(new_name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
10751113
tensor_storage.reverse_ne();
10761114

10771115
size_t tensor_data_size = end - begin;
@@ -1103,18 +1141,40 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
11031141
/*================================================= DiffusersModelLoader ==================================================*/
11041142

11051143
bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
1106-
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
1107-
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
1108-
std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
1144+
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
1145+
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
1146+
std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
1147+
std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
11091148

11101149
if (!init_from_safetensors_file(unet_path, "unet.")) {
11111150
return false;
11121151
}
1152+
for (auto ts : tensor_storages) {
1153+
if (ts.name.find("label_emb") != std::string::npos) {
1154+
// probably SDXL
1155+
LOG_DEBUG("Fixing name for SDXL output blocks.2.2");
1156+
for (auto& tensor_storage : tensor_storages) {
1157+
auto pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv");
1158+
if (pos != std::string::npos) {
1159+
tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(44);
1160+
LOG_DEBUG("NEW NAME: %s", tensor_storage.name.c_str());
1161+
add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
1162+
}
1163+
}
1164+
break;
1165+
}
1166+
}
1167+
11131168
if (!init_from_safetensors_file(vae_path, "vae.")) {
1114-
return false;
1169+
LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
1170+
// return false;
11151171
}
11161172
if (!init_from_safetensors_file(clip_path, "te.")) {
1117-
return false;
1173+
LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
1174+
// return false;
1175+
}
1176+
if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
1177+
LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
11181178
}
11191179
return true;
11201180
}

0 commit comments

Comments
 (0)