Skip to content

Commit a591398

Browse files
committed
Support NPUW for text-embedding models
1 parent df1c52d commit a591398

File tree

6 files changed

+336
-181
lines changed

6 files changed

+336
-181
lines changed

src/cpp/src/rag/text_embedding_pipeline.cpp

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,18 @@ std::optional<size_t> read_max_position_embeddings(const std::filesystem::path&
169169
return max_position_embeddings;
170170
}
171171

172+
std::string get_post_type_string(const TextEmbeddingPipeline::Config& config) {
173+
std::string post_type;
174+
if (config.pooling_type == TextEmbeddingPipeline::PoolingType::CLS) {
175+
post_type = "cls";
176+
} else if (config.pooling_type == TextEmbeddingPipeline::PoolingType::MEAN) {
177+
post_type = "mean";
178+
} else {
179+
post_type = "last_token";
180+
}
181+
return post_type;
182+
}
183+
172184
} // namespace
173185

174186
namespace ov {
@@ -211,32 +223,54 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
211223

212224
auto model = core.read_model(models_path / "openvino_model.xml", {}, properties);
213225

214-
const bool should_reshape = m_config.batch_size.has_value() || m_config.max_length.has_value();
215-
if (should_reshape) {
216-
reshape_model(model);
217-
}
218-
219-
if (device == "NPU") {
220-
OPENVINO_ASSERT(!model->is_dynamic(),
221-
"NPU device does not support dynamic shapes. In order to fix model shape, set batch_size, "
222-
"max_length and pad_to_max_length in the configuration.");
223-
}
224-
225-
model = apply_postprocessing(model, m_config);
226-
226+
bool is_fixed_size = true;
227227
if (m_config.max_length) {
228228
m_tokenization_params.insert({max_length.name(), *m_config.max_length});
229+
} else {
230+
is_fixed_size = false;
229231
}
230232

231233
if (m_config.pad_to_max_length) {
232234
m_tokenization_params.insert({pad_to_max_length.name(), *m_config.pad_to_max_length});
235+
is_fixed_size &= m_config.pad_to_max_length.value();
236+
} else {
237+
is_fixed_size = false;
233238
}
234239

240+
bool is_padding_on_left = false;
235241
if (m_config.padding_side) {
236242
m_tokenization_params.insert({padding_side.name(), *m_config.padding_side});
243+
if (m_config.padding_side.value() == "left") {
244+
is_padding_on_left = true;
245+
}
246+
}
247+
248+
bool should_reshape_non_npu =
249+
(device != "NPU" && (m_config.batch_size.has_value() || m_config.max_length.has_value()));
250+
bool should_reshape_npu = (device == "NPU" && m_config.batch_size.has_value() && is_fixed_size);
251+
if (should_reshape_non_npu || should_reshape_npu) {
252+
reshape_model(model);
237253
}
238254

239-
ov::CompiledModel compiled_model = core.compile_model(model, device, properties);
255+
ov::CompiledModel compiled_model;
256+
if (device == "NPU" && model->is_dynamic()) {
257+
OPENVINO_ASSERT(!(is_padding_on_left && is_fixed_size) ||
258+
config.pooling_type == TextEmbeddingPipeline::PoolingType::MEAN,
259+
"Padding on left is only supported for the mean post-processing type");
260+
261+
auto kv_pos = ov::genai::utils::get_kv_axes_pos(model);
262+
utils::KVDesc kv_desc;
263+
ov::AnyMap local_config;
264+
local_config["NPUW_TEXT_EMBED_POST_TYPE"] = get_post_type_string(config);
265+
if (m_config.max_length.has_value()) {
266+
local_config["MAX_PROMPT_LEN"] = m_config.max_length.value();
267+
}
268+
std::tie(compiled_model, kv_desc) =
269+
utils::compile_decoder_for_npu_text_embedding(model, properties, kv_pos, local_config);
270+
} else {
271+
model = apply_postprocessing(model, m_config);
272+
compiled_model = core.compile_model(model, device, properties);
273+
}
240274

241275
utils::print_compiled_model_properties(compiled_model, "text embedding model");
242276
m_request = compiled_model.create_infer_request();
@@ -383,7 +417,6 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
383417

384418
std::vector<std::vector<float>> result;
385419
const auto shape = last_hidden_state.get_shape();
386-
387420
const size_t batch_size = shape[0];
388421
const size_t hidden_size = shape[1];
389422

0 commit comments

Comments
 (0)