@@ -1194,6 +1194,10 @@ ov::AnyMap get_default_lm_head_config(const std::optional<NPUDesc>& npudesc) {
     return config;
 }
 
+ov::AnyMap get_default_text_embedding_post_config(const std::optional<NPUDesc>& npudesc) {
+    return get_default_lm_head_config(npudesc);
+}
+
 void merge_config_with(ov::AnyMap& lhs, const ov::AnyMap& rhs) {
     for (const auto& [key, value] : rhs) {
         // NB: Overwrite the value if key already exists
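
The NB comment above is the whole contract: entries from rhs win. As a minimal standalone sketch (the loop body is an assumption based on that comment, since the full upstream implementation is outside this hunk), the merge and its intended call order look like:

    #include <openvino/core/any.hpp>

    // Sketch: rhs entries overwrite lhs entries with the same key.
    void merge_config_with_sketch(ov::AnyMap& lhs, const ov::AnyMap& rhs) {
        for (const auto& [key, value] : rhs) {
            lhs[key] = value;  // std::map::operator[] inserts or overwrites
        }
    }

    // Intended usage, as seen later in this diff: defaults first, then
    // user-supplied properties layered on top so they take precedence.
    //   auto post_config = get_default_text_embedding_post_config(npudesc);
    //   merge_config_with(post_config, other_props);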
@@ -1237,6 +1241,10 @@ void update_config_for_whisper(ov::AnyMap& config) {
     config.erase("NPUW_SLICE_OUT");
 }
 
+void update_config_for_text_embed(ov::AnyMap& config) {
+    config.erase("NPUW_SLICE_OUT");
+}
+
 std::map<std::string, std::string> any_copy(const ov::AnyMap& params) {
     std::map<std::string, std::string> result;
     for (auto&& value : params) {
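
The rest of any_copy falls outside this hunk; a plausible completion (an assumption, following the usual ov::Any stringification pattern) would be:

    #include <map>
    #include <string>
    #include <openvino/core/any.hpp>

    // Sketch: convert each ov::Any value to its string representation.
    std::map<std::string, std::string> any_copy_sketch(const ov::AnyMap& params) {
        std::map<std::string, std::string> result;
        for (auto&& value : params) {
            result.emplace(value.first, value.second.as<std::string>());
        }
        return result;
    }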
@@ -1463,48 +1471,12 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         m_cfg.update({{"NPUW_LLM_OPTIMIZE_V_TENSORS", "NO"}});
     }
 
-    LOG_DEBUG("Creating kvcache model as clone of passed one.");
-    auto kvcache_model = model->clone();
-    LOG_DEBUG("Transform kvcache model from stateful to stateless.");
-    ov::pass::StatefulToStateless().run_on_model(kvcache_model);
-    convert_stateful_lora_to_stateless(kvcache_model);
-    LOG_DEBUG("...also convert BF16 to FP16");
-    // Note: we need to identify the original bf16 constants for potential weightless
-    // deserialization later, and only then do the bf16-to-f16 transformation
-    m_bf16_consts = ov::npuw::s11n::get_bf16_consts(model);
-    ov::pass::ConvertPrecision(ov::element::bf16, ov::element::f16).run_on_model(kvcache_model);
-
-    bool shared_head_enabled = m_cfg.get<::intel_npu::NPUW_LLM_SHARED_HEAD>();
-    std::shared_ptr<ov::Model> lm_head_model = nullptr;
-    if (shared_head_enabled) {
-        LOG_DEBUG("Trying to separate Vocabulary matrix multiplication op into additional model...");
-        lm_head_model = cut_lm_head(kvcache_model);
-        if (lm_head_model) {
-            LOG_INFO("Three-model pipeline will be created: LM head will be shared between prefill and generate.");
-        } else {
-            LOG_WARN("Three-model pipeline is requested, but LM head cutting failed;"
-                     " a two-model pipeline will be created!");
-        }
-    } else {
-        LOG_INFO("Two-model pipeline will be created.");
-    }
-
-    LOG_DEBUG("Try to patch the Phi-3 sliding window mask, if it exists.");
-    patch_phi3_sliding_mask(kvcache_model);
-
-    LOG_DEBUG("Creating prefill model as clone of transformed kvcache one.");
-    auto prefill_model = kvcache_model->clone();
-    prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill");
-
     // NB: PREFILL_HINT is now compatible with the PREFILL_CONFIG section; unlike for
     // the generate model, the two are not mutually exclusive
     const ::intel_npu::npuw::llm::PrefillHint prefill_hint = m_cfg.get<::intel_npu::NPUW_LLM_PREFILL_HINT>();
     m_prefill_chunk_size = m_cfg.get<::intel_npu::NPUW_LLM_PREFILL_CHUNK_SIZE>();
     m_use_chunk_prefill = (prefill_hint == ::intel_npu::npuw::llm::PrefillHint::DYNAMIC && m_prefill_chunk_size > 0);
 
-    const uint32_t batch_dim = m_cfg.get<::intel_npu::NPUW_LLM_BATCH_DIM>();
-    const uint32_t seq_len_dim = m_cfg.get<::intel_npu::NPUW_LLM_SEQ_LEN_DIM>();
-    KVAxesPosition axes{batch_dim, seq_len_dim};
     uint32_t max_prompt_len = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MAX_PROMPT_LEN>(), 64u);
     const uint32_t min_response_len = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MIN_RESPONSE_LEN>(), 64u);
    uint32_t max_generation_token_len = m_cfg.get<::intel_npu::NPUW_LLM_MAX_GENERATION_TOKEN_LEN>();
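
The text-embed branch added in the next hunk relies on pop_option() to consume NPUW_TEXT_EMBED and NPUW_TEXT_EMBED_POST_TYPE from other_props. Its definition is not part of this diff; a sketch of the assumed behavior (find, erase, return) is:

    #include <optional>
    #include <string>
    #include <openvino/core/any.hpp>

    // Assumed helper: take an option out of the map so it is not forwarded
    // to the submodel configs, returning its value if it was present.
    std::optional<ov::Any> pop_option(ov::AnyMap& options, const std::string& key) {
        if (auto it = options.find(key); it != options.end()) {
            auto found = std::make_optional(it->second);
            options.erase(it);  // consume the option
            return found;
        }
        return std::nullopt;
    }

    // Hence pop_option(other_props, "NPUW_TEXT_EMBED").value_or(false).as<bool>()
    // evaluates to false whenever the property was never set.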
@@ -1548,6 +1520,59 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
     LOG_VERB("Prefill chunk size: " << m_prefill_chunk_size);
     LOG_VERB("Maximum prompt length: " << max_prompt_len);
 
+    const uint32_t batch_dim = m_cfg.get<::intel_npu::NPUW_LLM_BATCH_DIM>();
+    const uint32_t seq_len_dim = m_cfg.get<::intel_npu::NPUW_LLM_SEQ_LEN_DIM>();
+    KVAxesPosition axes{batch_dim, seq_len_dim};
+
+    LOG_DEBUG("Creating kvcache model as clone of passed one.");
+    auto kvcache_model = model->clone();
+
+    auto use_text_embed_key = pop_option(other_props, std::string("NPUW_TEXT_EMBED"));
+    m_is_text_embed = use_text_embed_key.value_or(false).as<bool>();
+
+    std::shared_ptr<ov::Model> text_embedding_post_model = nullptr;
+    if (m_is_text_embed) {
+        if (m_use_chunk_prefill) {
+            LOG_DEBUG("Text-Embedding Chunk rebuild");
+            ov::npuw::util::prepare_text_embedding_model(kvcache_model, seq_len_dim);
+        }
+
+        auto post_type = pop_option(other_props, std::string("NPUW_TEXT_EMBED_POST_TYPE"));
+        ov::npuw::util::create_text_embedding_post_model(kvcache_model, text_embedding_post_model, post_type);
+    } else {
+        LOG_DEBUG("Transform kvcache model from stateful to stateless.");
+        ov::pass::StatefulToStateless().run_on_model(kvcache_model);
+        convert_stateful_lora_to_stateless(kvcache_model);
+    }
+
+    LOG_DEBUG("...also convert BF16 to FP16");
+    // Note: we need to identify the original bf16 constants for potential weightless
+    // deserialization later, and only then do the bf16-to-f16 transformation
+    m_bf16_consts = ov::npuw::s11n::get_bf16_consts(model);
+    ov::pass::ConvertPrecision(ov::element::bf16, ov::element::f16).run_on_model(kvcache_model);
+
+    bool shared_head_enabled = m_cfg.get<::intel_npu::NPUW_LLM_SHARED_HEAD>();
+    std::shared_ptr<ov::Model> lm_head_model = nullptr;
+    if (shared_head_enabled) {
+        LOG_DEBUG("Trying to separate Vocabulary matrix multiplication op into additional model...");
+        lm_head_model = cut_lm_head(kvcache_model);
+        if (lm_head_model) {
+            LOG_INFO("Three-model pipeline will be created: LM head will be shared between prefill and generate.");
+        } else {
+            LOG_WARN("Three-model pipeline is requested, but LM head cutting failed;"
+                     " a two-model pipeline will be created!");
+        }
+    } else {
+        LOG_INFO("Two-model pipeline will be created.");
+    }
+
+    LOG_DEBUG("Try to patch the Phi-3 sliding window mask, if it exists.");
+    patch_phi3_sliding_mask(kvcache_model);
+
+    LOG_DEBUG("Creating prefill model as clone of transformed kvcache one.");
+    auto prefill_model = kvcache_model->clone();
+    prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill");
+
     m_kvcache_desc =
         KVCacheDesc{max_prompt_len, max_prompt_len + min_response_len, 0u, seq_len_dim, max_generation_token_len};
 
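max_prompt_len and min_response_len come from align_to(..., 64u). The diff does not show align_to itself; assuming the conventional round-up-to-multiple semantics, it would be:

    #include <cstdint>

    // Assumed helper: round v up to the next multiple of a.
    constexpr uint32_t align_to(uint32_t v, uint32_t a) {
        return ((v + a - 1u) / a) * a;
    }
    static_assert(align_to(1000u, 64u) == 1024u, "rounds up");
    static_assert(align_to(1024u, 64u) == 1024u, "already aligned values are kept");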
@@ -1564,6 +1589,14 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         ov::npuw::util::prepare_whisper_kvcache_model(kvcache_model);  // Whisper decoder_with_past model
     }
 
+    if (m_is_text_embed) {
+        m_kvcache_desc = KVCacheDesc{max_prompt_len,
+                                     max_prompt_len + min_response_len,
+                                     0u,
+                                     seq_len_dim,
+                                     max_prompt_len + min_response_len};
+    }
+
     LOG_DEBUG("Make prefill model with static shapes");
     m_max_lora_rank = m_cfg.get<::intel_npu::NPUW_LLM_MAX_LORA_RANK>();
     if (m_use_chunk_prefill) {
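
To make the text-embed override above concrete, here is a worked example with assumed config values (a 1000-token prompt budget and a 128-token minimum response):

    // max_prompt_len   = align_to(1000, 64) = 1024
    // min_response_len = align_to(128, 64)  = 128
    // LLM path:        KVCacheDesc{1024, 1152, 0u, seq_len_dim, max_generation_token_len}
    // Text-embed path: KVCacheDesc{1024, 1152, 0u, seq_len_dim, 1152}
    // i.e. the last field is widened to the full context size, presumably because
    // an embedding model produces its output in one pass rather than token by token.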
@@ -1591,6 +1624,14 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         gemma_transformations(model_variant);
     }
 
+    if (text_embedding_post_model) {
+        auto input_node = text_embedding_post_model->inputs()[0];
+        NPUW_ASSERT(input_node.get_partial_shape().size() == 3u);
+        auto new_shape = ov::PartialShape({1, m_kvcache_desc.max_prompt_size, input_node.get_partial_shape()[2]});
+        auto mask_shape = ov::PartialShape({1, m_kvcache_desc.max_prompt_size});
+        text_embedding_post_model->reshape({{"input_ids", new_shape}, {"attention_mask", mask_shape}});
+    }
+
     if (lm_head_model) {
         LOG_DEBUG("Shared LM head: slice the prefill output");
         // KVCache model is already reshaped to [1, max_generation_token_len, embed size],
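
A concrete instance of the post-model reshape above, with assumed dimensions (hidden size 2048, max_prompt_size 1024); the port names "input_ids" and "attention_mask" are taken from the diff itself:

    #include <openvino/openvino.hpp>

    // Before: input_ids [?, ?, 2048], attention_mask [?, ?]   (dynamic dims)
    // After:  input_ids [1, 1024, 2048], attention_mask [1, 1024]  (fully static,
    // which is what the NPU compiler needs to see)
    void make_post_model_static(const std::shared_ptr<ov::Model>& post_model) {
        post_model->reshape({{"input_ids", ov::PartialShape({1, 1024, 2048})},
                             {"attention_mask", ov::PartialShape({1, 1024})}});
    }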
@@ -1653,18 +1694,23 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         LOG_DEBUG("Check and apply opt layout --- SKIPPED");
     }
 
-    if (!m_use_chunk_prefill) {
+    if (!m_use_chunk_prefill && !m_is_text_embed) {
         NPUW_ASSERT(remove_empty_kv_inputs(prefill_model));
     } else {
         LOG_DEBUG("Don't remove input key/values from prefill model.");
         LOG_DEBUG("Ask prefill model to output key/values for prefill chunk size tokens.");
-        prefill_model = redirect_new_kv_to_output(prefill_model);
+        if (!m_is_text_embed) {
+            prefill_model = redirect_new_kv_to_output(prefill_model);
+        }
     }
 
-    LOG_DEBUG("Optimize generate model to output key/values for new token.");
-    for (size_t i = 0; i < generate_model_variants.size(); ++i) {
-        generate_model_variants[i] = redirect_new_kv_to_output(generate_model_variants[i]);
+    if (!m_is_text_embed) {
+        LOG_DEBUG("Optimize generate model to output key/values for new token.");
+        for (size_t i = 0; i < generate_model_variants.size(); ++i) {
+            generate_model_variants[i] = redirect_new_kv_to_output(generate_model_variants[i]);
+        }
     }
+
     LOG_DEBUG("Converting KV-cache in generate model to FP16.");
     for (size_t i = 0; i < generate_model_variants.size(); ++i) {
         generate_model_variants[i] = cvt_kvcache_to_fp16(generate_model_variants[i]);
@@ -1730,6 +1776,10 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         update_config_for_whisper(prefill_config);
     }
 
+    if (m_is_text_embed) {
+        update_config_for_text_embed(prefill_config);
+    }
+
     if (m_cfg.get<::intel_npu::NPUW_LLM_CACHE_ROPE>()) {
         LOG_DEBUG("Caching preROPE");
         const uint32_t CACHE_ROPE_START = 2048;
@@ -1798,6 +1848,13 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         NPUW_ASSERT(m_lm_head_compiled);
     }
 
+    if (text_embedding_post_model) {
+        auto post_config = get_default_text_embedding_post_config(npudesc);
+        merge_config_with(post_config, other_props);
+        m_text_embedding_post_compiled = std::dynamic_pointer_cast<ov::npuw::CompiledModel>(
+            ov::npuw::ICompiledModel::create(text_embedding_post_model, plugin, post_config));
+    }
+
     implement_properties();
     LOG_DEBUG("Done");
 }
@@ -1908,6 +1965,7 @@ void ov::npuw::LLMCompiledModel::serialize(std::ostream& stream, const ov::npuw:
     write(model_stream, m_prefix_caching_max_num_blocks);
     write(model_stream, m_gemma_sliding_window_size);
     write(model_stream, m_is_whisper);
+    write(model_stream, m_is_text_embed);
 
     // Write config
     write(model_stream, m_cfg);
@@ -1931,6 +1989,12 @@ void ov::npuw::LLMCompiledModel::serialize(std::ostream& stream, const ov::npuw:
         if (is_shared_lm_head) {
             m_lm_head_compiled->serialize(model_stream, enc_ctx);
         }
+
+        const bool is_text_embed_post = m_text_embedding_post_compiled != nullptr;
+        write(model_stream, is_text_embed_post);
+        if (is_text_embed_post) {
+            m_text_embedding_post_compiled->serialize(model_stream, enc_ctx);
+        }
     };
 
     std::stringstream non_encrypted_stream;
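
The optional submodel follows the same presence-flag pattern already used for the shared LM head: write a bool, then the payload only if present, and mirror the exact order on the read side (see the deserialize hunk below). Schematically:

    // Writer:                                 // Reader (must match field order):
    // write(stream, m_is_whisper);            // read(stream, m_is_whisper);
    // write(stream, m_is_text_embed);         // read(stream, m_is_text_embed);
    // ...                                     // ...
    // write(stream, is_text_embed_post);      // read(stream, is_text_embed_post);
    // if (is_text_embed_post)                 // if (is_text_embed_post)
    //     m_text_embedding_post_compiled      //     m_text_embedding_post_compiled =
    //         ->serialize(stream, enc_ctx);   //         CompiledModel::deserialize(...);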
@@ -2033,6 +2097,11 @@ std::shared_ptr<ov::npuw::LLMCompiledModel> ov::npuw::LLMCompiledModel::import_m
             compiled->m_lm_head_compiled->m_weights_bank = bank;
             compiled->m_lm_head_compiled->finalize_weights_bank();
         }
+
+        if (compiled->m_text_embedding_post_compiled) {
+            compiled->m_text_embedding_post_compiled->m_weights_bank = bank;
+            compiled->m_text_embedding_post_compiled->finalize_weights_bank();
+        }
     } else {
         auto bank =
             ov::npuw::weights::Bank::deserialize(model_stream, compiled->get_plugin()->get_core(), bank_name);
@@ -2050,6 +2119,11 @@ std::shared_ptr<ov::npuw::LLMCompiledModel> ov::npuw::LLMCompiledModel::import_m
             compiled->m_lm_head_compiled->m_weights_bank = bank;
             compiled->m_lm_head_compiled->reconstruct_closure();
         }
+
+        if (compiled->m_text_embedding_post_compiled) {
+            compiled->m_text_embedding_post_compiled->m_weights_bank = bank;
+            compiled->m_text_embedding_post_compiled->reconstruct_closure();
+        }
     }
 };
 
@@ -2132,6 +2206,7 @@ std::shared_ptr<ov::npuw::LLMCompiledModel> ov::npuw::LLMCompiledModel::deserial
     read(model_stream, compiled->m_prefix_caching_max_num_blocks);
     read(model_stream, compiled->m_gemma_sliding_window_size);
     read(model_stream, compiled->m_is_whisper);
+    read(model_stream, compiled->m_is_text_embed);
 
     // Deserialize config
     read(model_stream, compiled->m_cfg);
@@ -2166,6 +2241,14 @@ std::shared_ptr<ov::npuw::LLMCompiledModel> ov::npuw::LLMCompiledModel::deserial
         compiled->m_lm_head_compiled =
             ov::npuw::CompiledModel::deserialize(model_stream, plugin, properties, enc_ctx);
     }
+
+    bool is_text_embed_post = false;
+    read(model_stream, is_text_embed_post);
+    if (is_text_embed_post) {
+        compiled->m_text_embedding_post_compiled =
+            ov::npuw::CompiledModel::deserialize(model_stream, plugin, properties, enc_ctx);
+    }
+
     return compiled;
 };
@@ -2247,6 +2330,7 @@ void ov::npuw::LLMCompiledModel::implement_properties() {
                     BIND(npuw::llm::prefill_attn_hint, NPUW_LLM_PREFILL_ATTENTION_HINT, getString),
                     BIND(npuw::llm::generate_attn_hint, NPUW_LLM_GENERATE_ATTENTION_HINT, getString),
                     BIND(npuw::llm::shared_lm_head, NPUW_LLM_SHARED_HEAD, get),
-                    BIND(npuw::whisper::enabled, NPUW_WHISPER, get)});
+                    BIND(npuw::whisper::enabled, NPUW_WHISPER, get),
+                    BIND(npuw::text_embed::enabled, NPUW_TEXT_EMBED, get)});
 #undef BIND
 }
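
With the new BIND entry, the flag round-trips as a compiled-model property. A hedged usage sketch (the string key "NPUW_TEXT_EMBED" is taken from this diff; that it is accepted verbatim at compile time and readable afterwards is an assumption):

    #include <openvino/openvino.hpp>

    ov::Core core;
    auto model = core.read_model("embedding_model.xml");  // hypothetical model path
    ov::AnyMap props = {{"NPUW_TEXT_EMBED", true}};
    ov::CompiledModel cm = core.compile_model(model, "NPU", props);
    bool text_embed_enabled = cm.get_property("NPUW_TEXT_EMBED").as<bool>();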