diff --git a/mseb/encoders/text_encoder_with_prompt.py b/mseb/encoders/text_encoder_with_prompt.py index 4e1c4889..21298803 100644 --- a/mseb/encoders/text_encoder_with_prompt.py +++ b/mseb/encoders/text_encoder_with_prompt.py @@ -234,11 +234,12 @@ def _encode( embeddings_batch.append(prompt.ProcessResponse(response)) outputs = [] - for embeddings, example, response in zip( - embeddings_batch, batch, response_batch + for embeddings, example, response, example_prompt in zip( + embeddings_batch, batch, response_batch, prompt_batch ): if isinstance(response, str): - debug_text = json.dumps({'model_response': response}) + debug_text = json.dumps({'prompt_text': example_prompt[0], + 'model_response': response}) else: debug_text = None assert isinstance(example, types.Text) or isinstance(example, types.Sound) diff --git a/mseb/encoders/text_encoder_with_prompt_test.py b/mseb/encoders/text_encoder_with_prompt_test.py index 7e2b32a0..03ca0a6e 100644 --- a/mseb/encoders/text_encoder_with_prompt_test.py +++ b/mseb/encoders/text_encoder_with_prompt_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging from typing import Callable from unittest import mock @@ -22,6 +23,7 @@ import numpy as np import pytest + text_encoder_with_prompt = pytest.importorskip( "mseb.encoders.text_encoder_with_prompt" ) @@ -169,6 +171,73 @@ def test_get_normalized_text_prompt_with_task_prompts(self): "default: hello", ) + def test_encode_with_task_prompts(self): + mock_encoder = MockTextEncoderWithPrompt( + prompt_template="default: {text}", + prompt_encode_fn=lambda prompts: [ + prompt[0].split(":")[0].upper() for prompt in prompts + ], + ) + mock_encoder.task_prompts = { + "Task1": prompt_lib.DefaultPrompt("task1 prompt: {text}"), + } + + mock_task1 = mock.MagicMock() + mock_task1.metadata.name = "Task1" + + mock_task2 = mock.MagicMock() + mock_task2.metadata.name = "Task2" + + mock_encoder.set_task(mock_task1) + embedding = mock_encoder.encode( + [ + types.Text( + text="hello", + context=types.TextContextParams(id="id1"), + ) + ] + )[0] + self.assertEqual( + embedding.embedding, + np.array("TASK1 PROMPT"), + ) + self.assertEqual( + embedding.context, + types.TextContextParams( + id="id1", + text="hello", + debug_text=json.dumps( + {"prompt_text": "task1 prompt: hello", + "model_response": "TASK1 PROMPT"} + ), + ), + ) + + mock_encoder.set_task(mock_task2) + embedding = mock_encoder.encode( + [ + types.Text( + text="hello", + context=types.TextContextParams(id="id1"), + ) + ] + )[0] + self.assertEqual( + embedding.embedding, + np.array("DEFAULT"), + ) + self.assertEqual( + embedding.context, + types.TextContextParams( + id="id1", + text="hello", + debug_text=json.dumps( + {"prompt_text": "default: hello", + "model_response": "DEFAULT"} + ), + ), + ) + if __name__ == "__main__": absltest.main()