
Commit 87fd20e

Support prefill-chunk for text-embedding model
1 parent 96f4dba commit 87fd20e

File tree

11 files changed: +618 -62 lines changed

src/plugins/intel_npu/src/al/include/intel_npu/config/npuw.hpp

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@ DEFINE_OPT(NPUW_LLM_ENABLE_PREFIX_CACHING, bool, false, npuw::llm::enable_prefix
 DEFINE_OPT(NPUW_LLM_PREFIX_CACHING_BLOCK_SIZE, uint64_t, 256, npuw::llm::prefix_caching_block_size, RunTime);
 DEFINE_OPT(NPUW_LLM_PREFIX_CACHING_MAX_NUM_BLOCKS, uint64_t, 128, npuw::llm::prefix_caching_max_num_blocks, RunTime);
 DEFINE_OPT(NPUW_WHISPER, bool, false, npuw::whisper::enabled, RunTime);
+DEFINE_OPT(NPUW_TEXT_EMBED, bool, false, npuw::text_embed::enabled, RunTime);
 DEFINE_ANYMAP_OPT(NPUW_LLM_PREFILL_CONFIG, npuw::llm::prefill_config);
 DEFINE_ANYMAP_OPT(NPUW_LLM_ADDITIONAL_PREFILL_CONFIG, npuw::llm::additional_prefill_config);
 DEFINE_ANYMAP_OPT(NPUW_LLM_GENERATE_CONFIG, npuw::llm::generate_config);

src/plugins/intel_npu/src/al/include/intel_npu/npuw_private_properties.hpp

Lines changed: 12 additions & 0 deletions
@@ -666,6 +666,18 @@ namespace whisper {
 static constexpr ov::Property<bool> enabled{"NPUW_WHISPER"};
 }  // namespace whisper

+namespace text_embed {
+/**
+ * @brief
+ * Type: bool.
+ * Tell NPUW that you want to pass text-embedding model.
+ * Default value: false.
+ */
+static constexpr ov::Property<bool> enabled{"NPUW_TEXT_EMBED"};
+static constexpr ov::Property<std::string> post_type{"NPUW_TEXT_EMBED_POST_TYPE"};
+
+}  // namespace text_embed
+
 }  // namespace npuw
 }  // namespace intel_npu
 }  // namespace ov
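For reference, a minimal sketch of how a client could opt into the new path, assuming the existing NPU_USE_NPUW/NPUW_LLM entry points and the chunked-prefill knobs that appear later in this diff; the model path is hypothetical, and since this commit does not show the accepted NPUW_TEXT_EMBED_POST_TYPE values, that key is left unset:

    #include <openvino/openvino.hpp>

    int main() {
        ov::Core core;
        ov::AnyMap config = {
            {"NPU_USE_NPUW", "YES"},               // route the model through NPUW
            {"NPUW_LLM", "YES"},                   // use the static LLM-style pipeline
            {"NPUW_TEXT_EMBED", true},             // new: mark the model as text-embedding
            {"NPUW_LLM_PREFILL_HINT", "DYNAMIC"},  // chunked prefill needs the DYNAMIC hint...
            {"NPUW_LLM_PREFILL_CHUNK_SIZE", 256u}, // ...and a non-zero chunk size
        };
        auto compiled = core.compile_model("text_embedding_model.xml", "NPU", config);
        return 0;
    }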

src/plugins/intel_npu/src/al/src/config/npuw.cpp

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ void intel_npu::registerNPUWLLMOptions(OptionsDesc& desc) {
     desc.add<NPUW_LLM_GENERATE_ATTENTION_HINT>();
     desc.add<NPUW_LLM_SHARED_HEAD>();
     desc.add<NPUW_WHISPER>();
+    desc.add<NPUW_TEXT_EMBED>();
 }

 std::string ov::npuw::s11n::anyToString(const ov::Any& var) {

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

Lines changed: 126 additions & 42 deletions
@@ -1194,6 +1194,10 @@ ov::AnyMap get_default_lm_head_config(const std::optional<NPUDesc>& npudesc) {
     return config;
 }

+ov::AnyMap get_default_text_embedding_post_config(const std::optional<NPUDesc>& npudesc) {
+    return get_default_lm_head_config(npudesc);
+}
+
 void merge_config_with(ov::AnyMap& lhs, const ov::AnyMap& rhs) {
     for (const auto& [key, value] : rhs) {
         // NB: Overwrite the value if key already exists
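The NB comment above pins down merge_config_with's conflict rule, which is also what lets user-supplied properties override the defaults returned by get_default_text_embedding_post_config further down. A tiny sketch with made-up keys:

    ov::AnyMap base = {{"KEY_A", "1"}, {"KEY_B", "2"}};
    merge_config_with(base, {{"KEY_B", "3"}});
    // base is now {"KEY_A": "1", "KEY_B": "3"} -- on a key collision,
    // the right-hand value overwrites the left-hand one.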
@@ -1237,6 +1241,10 @@ void update_config_for_whisper(ov::AnyMap& config) {
     config.erase("NPUW_SLICE_OUT");
 }

+void update_config_for_text_embed(ov::AnyMap& config) {
+    config.erase("NPUW_SLICE_OUT");
+}
+
 std::map<std::string, std::string> any_copy(const ov::AnyMap& params) {
     std::map<std::string, std::string> result;
     for (auto&& value : params) {
@@ -1463,48 +1471,12 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         m_cfg.update({{"NPUW_LLM_OPTIMIZE_V_TENSORS", "NO"}});
     }

-    LOG_DEBUG("Creating kvcache model as clone of passed one.");
-    auto kvcache_model = model->clone();
-    LOG_DEBUG("Transform kvcache model from stateful to stateless.");
-    ov::pass::StatefulToStateless().run_on_model(kvcache_model);
-    convert_stateful_lora_to_stateless(kvcache_model);
-    LOG_DEBUG(" ...also convert BF16 to FP16");
-    // Note: we need to identify original bf16 constants for potential weightless deserialization later
-    // And only then do bf16 to f16 transformation
-    m_bf16_consts = ov::npuw::s11n::get_bf16_consts(model);
-    ov::pass::ConvertPrecision(ov::element::bf16, ov::element::f16).run_on_model(kvcache_model);
-
-    bool shared_head_enabled = m_cfg.get<::intel_npu::NPUW_LLM_SHARED_HEAD>();
-    std::shared_ptr<ov::Model> lm_head_model = nullptr;
-    if (shared_head_enabled) {
-        LOG_DEBUG("Trying to separate Vocabulary matrix multiplication op into additional model...");
-        lm_head_model = cut_lm_head(kvcache_model);
-        if (lm_head_model) {
-            LOG_INFO("Three-model pipeline will be created: LM head will be shared between prefill and generate.");
-        } else {
-            LOG_WARN("Three-model pipeline is requested, but LM head cutting is failed,"
-                     " two-model pipeline will be created!");
-        }
-    } else {
-        LOG_INFO("Two-model pipeline will be created.");
-    }
-
-    LOG_DEBUG("Try patch Phi-3 sliding window mask, if it exists.");
-    patch_phi3_sliding_mask(kvcache_model);
-
-    LOG_DEBUG("Creating prefill model as clone of transformed kvcache one.");
-    auto prefill_model = kvcache_model->clone();
-    prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill");
-
     // NB: PREFILL_HINT is now compatible with the PREFILL_CONFIG section, unlike for
     // the generate model they're not mutually exclusive
     const ::intel_npu::npuw::llm::PrefillHint prefill_hint = m_cfg.get<::intel_npu::NPUW_LLM_PREFILL_HINT>();
     m_prefill_chunk_size = m_cfg.get<::intel_npu::NPUW_LLM_PREFILL_CHUNK_SIZE>();
     m_use_chunk_prefill = (prefill_hint == ::intel_npu::npuw::llm::PrefillHint::DYNAMIC && m_prefill_chunk_size > 0);

-    const uint32_t batch_dim = m_cfg.get<::intel_npu::NPUW_LLM_BATCH_DIM>();
-    const uint32_t seq_len_dim = m_cfg.get<::intel_npu::NPUW_LLM_SEQ_LEN_DIM>();
-    KVAxesPosition axes{batch_dim, seq_len_dim};
     uint32_t max_prompt_len = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MAX_PROMPT_LEN>(), 64u);
     const uint32_t min_response_len = align_to(m_cfg.get<::intel_npu::NPUW_LLM_MIN_RESPONSE_LEN>(), 64u);
     uint32_t max_generation_token_len = m_cfg.get<::intel_npu::NPUW_LLM_MAX_GENERATION_TOKEN_LEN>();
@@ -1548,6 +1520,59 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
     LOG_VERB("Prefill chunk size: " << m_prefill_chunk_size);
     LOG_VERB("Maximum prompt length: " << max_prompt_len);

+    const uint32_t batch_dim = m_cfg.get<::intel_npu::NPUW_LLM_BATCH_DIM>();
+    const uint32_t seq_len_dim = m_cfg.get<::intel_npu::NPUW_LLM_SEQ_LEN_DIM>();
+    KVAxesPosition axes{batch_dim, seq_len_dim};
+
+    LOG_DEBUG("Creating kvcache model as clone of passed one.");
+    auto kvcache_model = model->clone();
+
+    auto use_text_embed_key = pop_option(other_props, std::string("NPUW_TEXT_EMBED"));
+    m_is_text_embed = use_text_embed_key.value_or(false).as<bool>() == true;
+
+    std::shared_ptr<ov::Model> text_embedding_post_model = nullptr;
+    if (m_is_text_embed) {
+        if (m_use_chunk_prefill) {
+            LOG_DEBUG("Text-Embedding Chunk rebuild");
+            ov::npuw::util::prepare_text_embedding_model(kvcache_model, seq_len_dim);
+        }
+
+        auto post_type = pop_option(other_props, std::string("NPUW_TEXT_EMBED_POST_TYPE"));
+        ov::npuw::util::create_text_embedding_post_model(kvcache_model, text_embedding_post_model, post_type);
+    } else {
+        LOG_DEBUG("Transform kvcache model from stateful to stateless.");
+        ov::pass::StatefulToStateless().run_on_model(kvcache_model);
+        convert_stateful_lora_to_stateless(kvcache_model);
+    }
+
+    LOG_DEBUG(" ...also convert BF16 to FP16");
+    // Note: we need to identify original bf16 constants for potential weightless deserialization later
+    // And only then do bf16 to f16 transformation
+    m_bf16_consts = ov::npuw::s11n::get_bf16_consts(model);
+    ov::pass::ConvertPrecision(ov::element::bf16, ov::element::f16).run_on_model(kvcache_model);
+
+    bool shared_head_enabled = m_cfg.get<::intel_npu::NPUW_LLM_SHARED_HEAD>();
+    std::shared_ptr<ov::Model> lm_head_model = nullptr;
+    if (shared_head_enabled) {
+        LOG_DEBUG("Trying to separate Vocabulary matrix multiplication op into additional model...");
+        lm_head_model = cut_lm_head(kvcache_model);
+        if (lm_head_model) {
+            LOG_INFO("Three-model pipeline will be created: LM head will be shared between prefill and generate.");
+        } else {
+            LOG_WARN("Three-model pipeline is requested, but LM head cutting is failed,"
+                     " two-model pipeline will be created!");
+        }
+    } else {
+        LOG_INFO("Two-model pipeline will be created.");
+    }
+
+    LOG_DEBUG("Try patch Phi-3 sliding window mask, if it exists.");
+    patch_phi3_sliding_mask(kvcache_model);
+
+    LOG_DEBUG("Creating prefill model as clone of transformed kvcache one.");
+    auto prefill_model = kvcache_model->clone();
+    prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill");
+
     m_kvcache_desc =
         KVCacheDesc{max_prompt_len, max_prompt_len + min_response_len, 0u, seq_len_dim, max_generation_token_len};

@@ -1564,6 +1589,14 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         ov::npuw::util::prepare_whisper_kvcache_model(kvcache_model);  // Whisper decoder_with_past model
     }

+    if (m_is_text_embed) {
+        m_kvcache_desc = KVCacheDesc{max_prompt_len,
+                                     max_prompt_len + min_response_len,
+                                     0u,
+                                     seq_len_dim,
+                                     max_prompt_len + min_response_len};
+    }
+
     LOG_DEBUG("Make prefill model with static shapes");
     m_max_lora_rank = m_cfg.get<::intel_npu::NPUW_LLM_MAX_LORA_RANK>();
     if (m_use_chunk_prefill) {
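A worked sizing example for the text-embed KVCacheDesc above, under assumed settings NPUW_LLM_MAX_PROMPT_LEN = 1000 and NPUW_LLM_MIN_RESPONSE_LEN = 100 (align_to rounds up to the next multiple of 64):

    // max_prompt_len   = align_to(1000, 64u) = 1024
    // min_response_len = align_to(100, 64u)  = 128
    // Default LLM path: KVCacheDesc{1024, 1152, 0u, seq_len_dim, max_generation_token_len}
    // Text-embed path:  KVCacheDesc{1024, 1152, 0u, seq_len_dim, 1152}
    // The last field (max generation token length) is widened to the full context
    // size: an embedding model emits no autoregressive tokens to cap per step.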
@@ -1591,6 +1624,14 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         gemma_transformations(model_variant);
     }

+    if (text_embedding_post_model) {
+        auto input_node = text_embedding_post_model->inputs()[0];
+        NPUW_ASSERT(input_node.get_partial_shape().size() == 3u);
+        auto new_shape = ov::PartialShape({1, m_kvcache_desc.max_prompt_size, input_node.get_partial_shape()[2]});
+        auto mask_shape = ov::PartialShape({1, m_kvcache_desc.max_prompt_size});
+        text_embedding_post_model->reshape({{"input_ids", new_shape}, {"attention_mask", mask_shape}});
+    }
+
     if (lm_head_model) {
         LOG_DEBUG("Shared LM head: slice the prefill output");
         // KVCache model is already reshaped to [1, max_generation_token_len, embed size],
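The reshape above uses the ov::Model::reshape overload keyed by tensor names to pin every input of the post-model to a static shape. The same call on a standalone model, with a hypothetical file name and an assumed 1024-token context:

    ov::Core core;
    auto post_model = core.read_model("post_model.xml");  // hypothetical model file
    const auto hidden = post_model->input("input_ids").get_partial_shape()[2];
    post_model->reshape({{"input_ids", ov::PartialShape{1, 1024, hidden}},
                         {"attention_mask", ov::PartialShape{1, 1024}}});
    // Both inputs are now fully static: [1, 1024, hidden] and [1, 1024].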
@@ -1653,18 +1694,23 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         LOG_DEBUG("Check and apply opt layout --- SKIPPED");
     }

-    if (!m_use_chunk_prefill) {
+    if (!m_use_chunk_prefill && !m_is_text_embed) {
         NPUW_ASSERT(remove_empty_kv_inputs(prefill_model));
     } else {
         LOG_DEBUG("Don't remove input key/values from prefill model.");
         LOG_DEBUG("Ask prefill model to output key/values for prefill chunk size tokens.");
-        prefill_model = redirect_new_kv_to_output(prefill_model);
+        if (!m_is_text_embed) {
+            prefill_model = redirect_new_kv_to_output(prefill_model);
+        }
     }

-    LOG_DEBUG("Optimize generate model to output key/values for new token.");
-    for (size_t i = 0; i < generate_model_variants.size(); ++i) {
-        generate_model_variants[i] = redirect_new_kv_to_output(generate_model_variants[i]);
+    if (!m_is_text_embed) {
+        LOG_DEBUG("Optimize generate model to output key/values for new token.");
+        for (size_t i = 0; i < generate_model_variants.size(); ++i) {
+            generate_model_variants[i] = redirect_new_kv_to_output(generate_model_variants[i]);
+        }
     }
+
     LOG_DEBUG("Converting KV-cache in generate model to FP16.");
     for (size_t i = 0; i < generate_model_variants.size(); ++i) {
         generate_model_variants[i] = cvt_kvcache_to_fp16(generate_model_variants[i]);
@@ -1730,6 +1776,10 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         update_config_for_whisper(prefill_config);
     }

+    if (m_is_text_embed) {
+        update_config_for_text_embed(prefill_config);
+    }
+
     if (m_cfg.get<::intel_npu::NPUW_LLM_CACHE_ROPE>()) {
         LOG_DEBUG("Caching preROPE ");
         const uint32_t CACHE_ROPE_START = 2048;
@@ -1798,6 +1848,13 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
         NPUW_ASSERT(m_lm_head_compiled);
     }

+    if (text_embedding_post_model) {
+        auto post_config = get_default_text_embedding_post_config(npudesc);
+        merge_config_with(post_config, other_props);
+        m_text_embedding_post_compiled = std::dynamic_pointer_cast<ov::npuw::CompiledModel>(
+            ov::npuw::ICompiledModel::create(text_embedding_post_model, plugin, post_config));
+    }
+
     implement_properties();
     LOG_DEBUG("Done");
 }
@@ -1908,6 +1965,7 @@ void ov::npuw::LLMCompiledModel::serialize(std::ostream& stream, const ov::npuw:
         write(model_stream, m_prefix_caching_max_num_blocks);
         write(model_stream, m_gemma_sliding_window_size);
         write(model_stream, m_is_whisper);
+        write(model_stream, m_is_text_embed);

         // Write config
         write(model_stream, m_cfg);
@@ -1931,6 +1989,12 @@ void ov::npuw::LLMCompiledModel::serialize(std::ostream& stream, const ov::npuw:
         if (is_shared_lm_head) {
             m_lm_head_compiled->serialize(model_stream, enc_ctx);
         }
+
+        const bool is_text_embed_post = m_text_embedding_post_compiled != nullptr;
+        write(model_stream, is_text_embed_post);
+        if (is_text_embed_post) {
+            m_text_embedding_post_compiled->serialize(model_stream, enc_ctx);
+        }
     };

     std::stringstream non_encrypted_stream;
@@ -2033,6 +2097,11 @@ std::shared_ptr<ov::npuw::LLMCompiledModel> ov::npuw::LLMCompiledModel::import_m
             compiled->m_lm_head_compiled->m_weights_bank = bank;
             compiled->m_lm_head_compiled->finalize_weights_bank();
         }
+
+        if (compiled->m_text_embedding_post_compiled) {
+            compiled->m_text_embedding_post_compiled->m_weights_bank = bank;
+            compiled->m_text_embedding_post_compiled->finalize_weights_bank();
+        }
     } else {
         auto bank =
             ov::npuw::weights::Bank::deserialize(model_stream, compiled->get_plugin()->get_core(), bank_name);
@@ -2050,6 +2119,11 @@ std::shared_ptr<ov::npuw::LLMCompiledModel> ov::npuw::LLMCompiledModel::import_m
             compiled->m_lm_head_compiled->m_weights_bank = bank;
             compiled->m_lm_head_compiled->reconstruct_closure();
         }
+
+        if (compiled->m_text_embedding_post_compiled) {
+            compiled->m_text_embedding_post_compiled->m_weights_bank = bank;
+            compiled->m_text_embedding_post_compiled->reconstruct_closure();
+        }
     }
 };

@@ -2132,6 +2206,7 @@ std::shared_ptr<ov::npuw::LLMCompiledModel> ov::npuw::LLMCompiledModel::deserial
     read(model_stream, compiled->m_prefix_caching_max_num_blocks);
     read(model_stream, compiled->m_gemma_sliding_window_size);
     read(model_stream, compiled->m_is_whisper);
+    read(model_stream, compiled->m_is_text_embed);

     // Deserialize config
     read(model_stream, compiled->m_cfg);
@@ -2166,6 +2241,14 @@ std::shared_ptr<ov::npuw::LLMCompiledModel> ov::npuw::LLMCompiledModel::deserial
         compiled->m_lm_head_compiled =
             ov::npuw::CompiledModel::deserialize(model_stream, plugin, properties, enc_ctx);
     }
+
+    bool is_text_embed_post = false;
+    read(model_stream, is_text_embed_post);
+    if (is_text_embed_post) {
+        compiled->m_text_embedding_post_compiled =
+            ov::npuw::CompiledModel::deserialize(model_stream, plugin, properties, enc_ctx);
+    }
+
     return compiled;
 };

@@ -2247,6 +2330,7 @@ void ov::npuw::LLMCompiledModel::implement_properties() {
                            BIND(npuw::llm::prefill_attn_hint, NPUW_LLM_PREFILL_ATTENTION_HINT, getString),
                            BIND(npuw::llm::generate_attn_hint, NPUW_LLM_GENERATE_ATTENTION_HINT, getString),
                            BIND(npuw::llm::shared_lm_head, NPUW_LLM_SHARED_HEAD, get),
-                           BIND(npuw::whisper::enabled, NPUW_WHISPER, get)});
+                           BIND(npuw::whisper::enabled, NPUW_WHISPER, get),
+                           BIND(npuw::text_embed::enabled, NPUW_TEXT_EMBED, get)});
 #undef BIND
 }
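The serialize/deserialize hunks above guard the optional post-model with a presence flag, mirroring the shared-LM-head handling. A self-contained sketch of that presence-flag pattern, with simplified stand-in types rather than the plugin's actual stream helpers:

    #include <iostream>
    #include <memory>
    #include <sstream>
    #include <string>

    // Stand-in for a serializable sub-model (hypothetical type).
    struct Blob {
        std::string payload;
        void serialize(std::ostream& os) const { os << payload << '\n'; }
        static Blob deserialize(std::istream& is) {
            Blob b;
            std::getline(is, b.payload);
            return b;
        }
    };

    int main() {
        std::shared_ptr<Blob> post;  // may be null, like m_text_embedding_post_compiled
        std::stringstream stream;

        // Write side: presence flag first, body only when present.
        const bool has_post = (post != nullptr);
        stream << has_post << '\n';
        if (has_post) {
            post->serialize(stream);
        }

        // Read side: mirror the flag, then conditionally deserialize.
        bool flag = false;
        stream >> flag;
        stream.ignore();  // skip the newline after the flag
        std::shared_ptr<Blob> restored;
        if (flag) {
            restored = std::make_shared<Blob>(Blob::deserialize(stream));
        }
        return 0;
    }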

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.hpp

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,7 @@ class LLMCompiledModel : public ov::npuw::ICompiledModel {
     std::shared_ptr<ov::npuw::CompiledModel> m_prefill_compiled;
     // This model is optional, so can be null.
     std::shared_ptr<ov::npuw::CompiledModel> m_lm_head_compiled;
+    std::shared_ptr<ov::npuw::CompiledModel> m_text_embedding_post_compiled;

     // Multiple generate models with different static KV cache shapes (1K, 2K, 4K, 8K stepping)
     std::vector<std::shared_ptr<ov::npuw::CompiledModel>> m_generate_compiled_variants;

@@ -113,6 +114,7 @@ class LLMCompiledModel : public ov::npuw::ICompiledModel {
     int32_t m_gemma_sliding_window_size = 0;

     bool m_is_whisper = false;
+    bool m_is_text_embed = false;

     // Create generate model variants with different sizes
     std::vector<std::shared_ptr<ov::Model>> create_generate_model_variants(
