@@ -275,6 +275,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
275275 {" to_v" , " v" },
276276 {" to_out_0" , " proj_out" },
277277 {" group_norm" , " norm" },
278+ {" key" , " k" },
279+ {" query" , " q" },
280+ {" value" , " v" },
281+ {" proj_attn" , " proj_out" },
278282 },
279283 },
280284 {
@@ -299,6 +303,10 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
299303 {" to_v" , " v" },
300304 {" to_out.0" , " proj_out" },
301305 {" group_norm" , " norm" },
306+ {" key" , " k" },
307+ {" query" , " q" },
308+ {" value" , " v" },
309+ {" proj_attn" , " proj_out" },
302310 },
303311 },
304312 {
@@ -370,6 +378,10 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
370378 return format (" model%cdiffusion_model%ctime_embed%c" , seq, seq, seq) + std::to_string (std::stoi (m[0 ]) * 2 - 2 ) + m[1 ];
371379 }
372380
381+ if (match (m, std::regex (format (" unet%cadd_embedding%clinear_(\\ d+)(.*)" , seq, seq)), key)) {
382+ return format (" model%cdiffusion_model%clabel_emb%c0%c" , seq, seq, seq, seq) + std::to_string (std::stoi (m[0 ]) * 2 - 2 ) + m[1 ];
383+ }
384+
373385 if (match (m, std::regex (format (" unet%cdown_blocks%c(\\ d+)%c(attentions|resnets)%c(\\ d+)%c(.+)" , seq, seq, seq, seq, seq)), key)) {
374386 std::string suffix = get_converted_suffix (m[1 ], m[3 ]);
375387 // LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
@@ -407,6 +419,19 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
407419 return format (" cond_stage_model%ctransformer%ctext_model" , seq, seq) + m[0 ];
408420 }
409421
422+ // clip-g
423+ if (match (m, std::regex (format (" te%c1%ctext_model%cencoder%clayers%c(\\ d+)%c(.+)" , seq, seq, seq, seq, seq, seq)), key)) {
424+ return format (" cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c" , seq, seq, seq, seq, seq, seq) + m[0 ] + seq + m[1 ];
425+ }
426+
427+ if (match (m, std::regex (format (" te%c1%ctext_model(.*)" , seq, seq)), key)) {
428+ return format (" cond_stage_model%c1%ctransformer%ctext_model" , seq, seq, seq) + m[0 ];
429+ }
430+
431+ if (match (m, std::regex (format (" te%c1%ctext_projection" , seq, seq)), key)) {
432+ return format (" cond_stage_model%c1%ctransformer%ctext_model%ctext_projection" , seq, seq, seq, seq);
433+ }
434+
410435 // vae
411436 if (match (m, std::regex (format (" vae%c(.*)%cconv_norm_out(.*)" , seq, seq)), key)) {
412437 return format (" first_stage_model%c%s%cnorm_out%s" , seq, m[0 ].c_str (), seq, m[1 ].c_str ());
@@ -543,6 +568,8 @@ std::string convert_tensor_name(std::string name) {
543568 std::string new_key = convert_diffusers_name_to_compvis (name_without_network_parts, ' .' );
544569 if (new_key.empty ()) {
545570 new_name = name;
571+ } else if (new_key == " cond_stage_model.1.transformer.text_model.text_projection" ) {
572+ new_name = new_key;
546573 } else {
547574 new_name = new_key + " ." + network_part;
548575 }
@@ -966,10 +993,14 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
966993 ttype = GGML_TYPE_F32;
967994 } else if (dtype == " F32" ) {
968995 ttype = GGML_TYPE_F32;
996+ } else if (dtype == " F64" ) {
997+ ttype = GGML_TYPE_F64;
969998 } else if (dtype == " F8_E4M3" ) {
970999 ttype = GGML_TYPE_F16;
9711000 } else if (dtype == " F8_E5M2" ) {
9721001 ttype = GGML_TYPE_F16;
1002+ } else if (dtype == " I64" ) {
1003+ ttype = GGML_TYPE_I64;
9731004 }
9741005 return ttype;
9751006}
@@ -982,6 +1013,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
9821013 std::ifstream file (file_path, std::ios::binary);
9831014 if (!file.is_open ()) {
9841015 LOG_ERROR (" failed to open '%s'" , file_path.c_str ());
1016+ file_paths_.pop_back ();
9851017 return false ;
9861018 }
9871019
@@ -993,6 +1025,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
9931025 // read header size
9941026 if (file_size_ <= ST_HEADER_SIZE_LEN) {
9951027 LOG_ERROR (" invalid safetensor file '%s'" , file_path.c_str ());
1028+ file_paths_.pop_back ();
9961029 return false ;
9971030 }
9981031
@@ -1006,6 +1039,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10061039 size_t header_size_ = read_u64 (header_size_buf);
10071040 if (header_size_ >= file_size_) {
10081041 LOG_ERROR (" invalid safetensor file '%s'" , file_path.c_str ());
1042+ file_paths_.pop_back ();
10091043 return false ;
10101044 }
10111045
@@ -1016,6 +1050,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10161050 file.read (header_buf.data (), header_size_);
10171051 if (!file) {
10181052 LOG_ERROR (" read safetensors header failed: '%s'" , file_path.c_str ());
1053+ file_paths_.pop_back ();
10191054 return false ;
10201055 }
10211056
@@ -1071,7 +1106,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10711106 n_dims = 1 ;
10721107 }
10731108
1074- TensorStorage tensor_storage (prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
1109+ std::string new_name = prefix + name;
1110+ new_name = convert_tensor_name (new_name);
1111+
1112+ TensorStorage tensor_storage (new_name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
10751113 tensor_storage.reverse_ne ();
10761114
10771115 size_t tensor_data_size = end - begin;
@@ -1103,18 +1141,40 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
11031141/* ================================================= DiffusersModelLoader ==================================================*/
11041142
11051143bool ModelLoader::init_from_diffusers_file (const std::string& file_path, const std::string& prefix) {
1106- std::string unet_path = path_join (file_path, " unet/diffusion_pytorch_model.safetensors" );
1107- std::string vae_path = path_join (file_path, " vae/diffusion_pytorch_model.safetensors" );
1108- std::string clip_path = path_join (file_path, " text_encoder/model.safetensors" );
1144+ std::string unet_path = path_join (file_path, " unet/diffusion_pytorch_model.safetensors" );
1145+ std::string vae_path = path_join (file_path, " vae/diffusion_pytorch_model.safetensors" );
1146+ std::string clip_path = path_join (file_path, " text_encoder/model.safetensors" );
1147+ std::string clip_g_path = path_join (file_path, " text_encoder_2/model.safetensors" );
11091148
11101149 if (!init_from_safetensors_file (unet_path, " unet." )) {
11111150 return false ;
11121151 }
1152+ for (auto ts : tensor_storages) {
1153+ if (ts.name .find (" label_emb" ) != std::string::npos) {
1154+ // probably SDXL
1155+ LOG_DEBUG (" Fixing name for SDXL output blocks.2.2" );
1156+ for (auto & tensor_storage : tensor_storages) {
1157+ auto pos = tensor_storage.name .find (" model.diffusion_model.output_blocks.2.1.conv" );
1158+ if (pos != std::string::npos) {
1159+ tensor_storage.name = " model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name .substr (44 );
1160+ LOG_DEBUG (" NEW NAME: %s" , tensor_storage.name .c_str ());
1161+ add_preprocess_tensor_storage_types (tensor_storages_types, tensor_storage.name , tensor_storage.type );
1162+ }
1163+ }
1164+ break ;
1165+ }
1166+ }
1167+
11131168 if (!init_from_safetensors_file (vae_path, " vae." )) {
1114- return false ;
1169+ LOG_WARN (" Couldn't find working VAE in %s" , file_path.c_str ());
1170+ // return false;
11151171 }
11161172 if (!init_from_safetensors_file (clip_path, " te." )) {
1117- return false ;
1173+ LOG_WARN (" Couldn't find working text encoder in %s" , file_path.c_str ());
1174+ // return false;
1175+ }
1176+ if (!init_from_safetensors_file (clip_g_path, " te.1." )) {
1177+ LOG_DEBUG (" Couldn't find working second text encoder in %s" , file_path.c_str ());
11181178 }
11191179 return true ;
11201180}
0 commit comments