Skip to content

Commit deb6145

Browse files
authored
Refactor and update file checks logic (#373)
1 parent 423bcac commit deb6145

File tree

3 files changed

+202
-60
lines changed

3 files changed

+202
-60
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.5.29"
15+
version = "1.5.30"
1616
authors = ["Together AI <[email protected]>"]
1717
description = "Python client for Together's Cloud Platform!"
1818
readme = "README.md"

src/together/utils/files.py

Lines changed: 175 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -102,81 +102,163 @@ def check_file(
102102
return report_dict
103103

104104

105-
def validate_messages(messages: List[Dict[str, str | bool]], idx: int) -> None:
106-
"""Validate the messages column."""
105+
def _check_conversation_type(messages: List[Dict[str, str | bool]], idx: int) -> None:
106+
"""Check that the conversation has correct type.
107+
108+
Args:
109+
messages: The messages in the conversation.
110+
Can be any type, this function ensures that the messages are a list of dictionaries.
111+
idx: Line number in the file.
112+
113+
Raises:
114+
InvalidFileFormatError: If the conversation type is invalid.
115+
"""
107116
if not isinstance(messages, list):
108117
raise InvalidFileFormatError(
109118
message=f"Invalid format on line {idx + 1} of the input file. "
110-
f"Expected a list of messages. Found {type(messages)}",
119+
f"The `messages` column must be a list. Found {type(messages)}",
111120
line_number=idx + 1,
112121
error_source="key_value",
113122
)
114-
if not messages:
123+
if len(messages) == 0:
115124
raise InvalidFileFormatError(
116125
message=f"Invalid format on line {idx + 1} of the input file. "
117-
f"Expected a non-empty list of messages. Found empty list",
126+
f"The `messages` column must not be empty.",
118127
line_number=idx + 1,
119128
error_source="key_value",
120129
)
121130

122-
has_weights = any("weight" in message for message in messages)
123-
124-
previous_role = None
125131
for message in messages:
126132
if not isinstance(message, dict):
127133
raise InvalidFileFormatError(
128134
message=f"Invalid format on line {idx + 1} of the input file. "
129-
f"Expected a dictionary in the messages list. Found {type(message)}",
135+
f"The `messages` column must be a list of dicts. Found {type(message)}",
130136
line_number=idx + 1,
131137
error_source="key_value",
132138
)
139+
133140
for column in REQUIRED_COLUMNS_MESSAGE:
134141
if column not in message:
135142
raise InvalidFileFormatError(
136-
message=f"Field `{column}` is missing for a turn `{message}` on line {idx + 1} "
137-
"of the the input file.",
143+
message=f"Missing required column `{column}` in message on line {idx + 1}.",
138144
line_number=idx + 1,
139145
error_source="key_value",
140146
)
141-
else:
142-
if not isinstance(message[column], str):
143-
raise InvalidFileFormatError(
144-
message=f"Invalid format on line {idx + 1} in the column {column} for turn `{message}` "
145-
f"of the input file. Expected string. Found {type(message[column])}",
146-
line_number=idx + 1,
147-
error_source="text_field",
148-
)
149-
150-
if has_weights and "weight" in message:
151-
weight = message["weight"]
152-
if not isinstance(weight, int):
153-
raise InvalidFileFormatError(
154-
message="Weight must be an integer",
155-
line_number=idx + 1,
156-
error_source="key_value",
157-
)
158-
if weight not in {0, 1}:
147+
if not isinstance(message[column], str):
159148
raise InvalidFileFormatError(
160-
message="Weight must be either 0 or 1",
149+
message=f"Column `{column}` is not a string on line {idx + 1}. Found {type(message[column])}",
161150
line_number=idx + 1,
162-
error_source="key_value",
151+
error_source="text_field",
163152
)
164-
if message["role"] not in POSSIBLE_ROLES_CONVERSATION:
153+
154+
155+
def _check_conversation_roles(
156+
require_assistant_role: bool, assistant_role_exists: bool, idx: int
157+
) -> None:
158+
"""Check that the conversation has correct roles.
159+
160+
Args:
161+
require_assistant_role: Whether to require at least one assistant role.
162+
assistant_role_exists: Whether an assistant role exists in the conversation.
163+
idx: Line number in the file.
164+
165+
Raises:
166+
InvalidFileFormatError: If the conversation roles are invalid.
167+
"""
168+
if require_assistant_role and not assistant_role_exists:
169+
raise InvalidFileFormatError(
170+
message=f"Invalid format on line {idx + 1} of the input file. "
171+
"At least one message with the assistant role must be present in the example.",
172+
line_number=idx + 1,
173+
error_source="key_value",
174+
)
175+
176+
177+
def _check_message_weight(message: Dict[str, str | bool], idx: int) -> None:
178+
"""Check that the message has a weight with the correct type and value.
179+
180+
Args:
181+
message: The message to check.
182+
idx: Line number in the file.
183+
184+
Raises:
185+
InvalidFileFormatError: If the message weight is invalid.
186+
"""
187+
if "weight" in message:
188+
weight = message["weight"]
189+
if not isinstance(weight, int):
165190
raise InvalidFileFormatError(
166-
message=f"Found invalid role `{message['role']}` in the messages on the line {idx + 1}. "
167-
f"Possible roles in the conversation are: {POSSIBLE_ROLES_CONVERSATION}",
191+
message=f"Weight must be an integer on line {idx + 1}.",
168192
line_number=idx + 1,
169193
error_source="key_value",
170194
)
171-
172-
if previous_role == message["role"]:
195+
if weight not in {0, 1}:
173196
raise InvalidFileFormatError(
174-
message=f"Invalid role turns on line {idx + 1} of the input file. "
175-
"`user` and `assistant` roles must alternate user/assistant/user/assistant/...",
197+
message=f"Weight must be either 0 or 1 on line {idx + 1}.",
176198
line_number=idx + 1,
177199
error_source="key_value",
178200
)
179-
previous_role = message["role"]
201+
202+
203+
def _check_message_role(
    message: Dict[str, str | bool], previous_role: str | None, idx: int
) -> str | bool:
    """Validate a single message's role and return it for the next comparison.

    Args:
        message: The message whose `role` entry is inspected.
        previous_role: Role to compare against; None means there is no
            earlier role, so the repetition check is skipped.
        idx: Index of the line in the file; errors report it as `idx + 1`.

    Returns:
        The value stored under `role` in `message`.

    Raises:
        InvalidFileFormatError: If the role is not in
            POSSIBLE_ROLES_CONVERSATION, or if it equals `previous_role`.
    """
    role = message["role"]

    if role not in POSSIBLE_ROLES_CONVERSATION:
        raise InvalidFileFormatError(
            message=f"Invalid role `{message['role']}` in conversation on line {idx + 1}. "
            f"Possible roles: {', '.join(POSSIBLE_ROLES_CONVERSATION)}",
            line_number=idx + 1,
            error_source="key_value",
        )

    # Two consecutive messages with the same role break the turn alternation.
    repeats_previous = previous_role is not None and role == previous_role
    if repeats_previous:
        raise InvalidFileFormatError(
            message=f"Invalid role turns on line {idx + 1} of the input file. "
            "After the optional system message, conversation roles must alternate between user/assistant/user/assistant.",
            line_number=idx + 1,
            error_source="key_value",
        )

    return role
234+
235+
236+
def validate_messages(
    messages: List[Dict[str, str | bool]], idx: int, require_assistant_role: bool = True
) -> None:
    """Validate the messages column.

    Runs the conversation-level type check first, then checks each message's
    optional weight and its role in order, and finally verifies that an
    assistant message was seen when one is required.

    Args:
        messages: List of message dictionaries to validate.
        idx: Index of the line in the file; errors report it as `idx + 1`.
        require_assistant_role: Whether to require at least one assistant role.

    Raises:
        InvalidFileFormatError: If the messages are invalid.
    """
    _check_conversation_type(messages, idx)

    previous_role = None
    assistant_role_exists = False

    for message in messages:
        # `_check_message_weight` does nothing for a message without a
        # `weight` key, so no pre-scan of the whole conversation is needed.
        _check_message_weight(message, idx)
        previous_role = _check_message_role(message, previous_role, idx)
        if previous_role == "assistant":
            assistant_role_exists = True

    _check_conversation_roles(require_assistant_role, assistant_role_exists, idx)
180262

181263

182264
def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None:
@@ -203,37 +285,73 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None:
203285
error_source="key_value",
204286
)
205287

206-
validate_messages(example["input"]["messages"], idx)
288+
validate_messages(example["input"]["messages"], idx, require_assistant_role=False)
289+
290+
if example["input"]["messages"][-1]["role"] == "assistant":
291+
raise InvalidFileFormatError(
292+
message=f"The last message in the input conversation must not be from the assistant on line {idx + 1}.",
293+
line_number=idx + 1,
294+
error_source="key_value",
295+
)
296+
297+
keys = ["preferred_output", "non_preferred_output"]
298+
299+
for key in keys:
300+
if key not in example:
301+
raise InvalidFileFormatError(
302+
message=f"The dataset is malformed, the `{key}` field must be present in the input dictionary on line {idx + 1}.",
303+
line_number=idx + 1,
304+
error_source="key_value",
305+
)
306+
307+
if not isinstance(example[key], list):
308+
raise InvalidFileFormatError(
309+
message=f"The dataset is malformed, the `{key}` field must be a list on line {idx + 1}.",
310+
line_number=idx + 1,
311+
error_source="key_value",
312+
)
313+
314+
if len(example[key]) != 1:
315+
raise InvalidFileFormatError(
316+
message=f"The dataset is malformed, the `{key}` list must contain exactly one message on line {idx + 1}.",
317+
line_number=idx + 1,
318+
error_source="key_value",
319+
)
207320

208-
for output_field in ["preferred_output", "non_preferred_output"]:
209-
if not isinstance(example[output_field], list):
321+
if not isinstance(example[key][0], dict):
210322
raise InvalidFileFormatError(
211-
message=f"The dataset is malformed, the `{output_field}` field must be a list.",
323+
message=f"The dataset is malformed, the first element of `{key}` must be a dictionary on line {idx + 1}.",
212324
line_number=idx + 1,
213325
error_source="key_value",
214326
)
215327

216-
if len(example[output_field]) != 1:
328+
if "role" not in example[key][0]:
217329
raise InvalidFileFormatError(
218-
message=f"The dataset is malformed, the `{output_field}` list must contain exactly one message.",
330+
message=f"The dataset is malformed, the first element of `{key}` must have a 'role' field on line {idx + 1}.",
219331
line_number=idx + 1,
220332
error_source="key_value",
221333
)
222-
if "role" not in example[output_field][0]:
334+
335+
if example[key][0]["role"] != "assistant":
223336
raise InvalidFileFormatError(
224-
message=f"The dataset is malformed, the `{output_field}` message is missing the `role` field.",
337+
message=f"The dataset is malformed, the first element of `{key}` must have the 'assistant' role on line {idx + 1}.",
225338
line_number=idx + 1,
226339
error_source="key_value",
227340
)
228-
elif example[output_field][0]["role"] != "assistant":
341+
342+
if "content" not in example[key][0]:
229343
raise InvalidFileFormatError(
230-
message=f"The dataset is malformed, the `{output_field}` must contain an assistant message.",
344+
message=f"The dataset is malformed, the first element of `{key}` must have a 'content' field on line {idx + 1}.",
231345
line_number=idx + 1,
232346
error_source="key_value",
233347
)
234348

235-
validate_messages(example["preferred_output"], idx)
236-
validate_messages(example["non_preferred_output"], idx)
349+
if not isinstance(example[key][0]["content"], str):
350+
raise InvalidFileFormatError(
351+
message=f"The dataset is malformed, the 'content' field in `{key}` must be a string on line {idx + 1}.",
352+
line_number=idx + 1,
353+
error_source="key_value",
354+
)
237355

238356

239357
def _check_utf8(file: Path) -> Dict[str, Any]:
@@ -410,7 +528,12 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
410528
message_column = JSONL_REQUIRED_COLUMNS_MAP[
411529
DatasetFormat.CONVERSATION
412530
][0]
413-
validate_messages(json_line[message_column], idx)
531+
require_assistant = purpose != FilePurpose.Eval
532+
validate_messages(
533+
json_line[message_column],
534+
idx,
535+
require_assistant_role=require_assistant,
536+
)
414537
else:
415538
for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
416539
if not isinstance(json_line[column], str):

tests/unit/test_files_checks.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,12 @@ def test_check_jsonl_inconsistent_dataset_format(tmp_path: Path):
182182
# Create a JSONL file with inconsistent dataset formats
183183
file = tmp_path / "inconsistent_format.jsonl"
184184
content = [
185-
{"messages": [{"role": "user", "content": "Hi"}]},
185+
{
186+
"messages": [
187+
{"role": "user", "content": "Hi"},
188+
{"role": "assistant", "content": "Hi! How can I help you?"},
189+
]
190+
},
186191
{"text": "How are you?"}, # Missing 'messages'
187192
]
188193
with file.open("w") as f:
@@ -207,7 +212,7 @@ def test_check_jsonl_invalid_role(tmp_path: Path):
207212
report = check_file(file)
208213

209214
assert not report["is_check_passed"]
210-
assert "Found invalid role `invalid_role`" in report["message"]
215+
assert "Invalid role `invalid_role` in conversation" in report["message"]
211216

212217

213218
def test_check_jsonl_non_alternating_roles(tmp_path: Path):
@@ -230,6 +235,22 @@ def test_check_jsonl_non_alternating_roles(tmp_path: Path):
230235
assert "Invalid role turns" in report["message"]
231236

232237

238+
def test_check_jsonl_assistant_role_exists(tmp_path: Path):
    """A conversation with no assistant message must fail the file check."""
    # One JSONL row whose only message comes from the user.
    user_only_rows = [{"messages": [{"role": "user", "content": "Hi"}]}]
    jsonl_path = tmp_path / "assistant_role_exists.jsonl"
    jsonl_path.write_text("\n".join(map(json.dumps, user_only_rows)))

    report = check_file(jsonl_path)

    expected_fragment = "At least one message with the assistant role must be present"
    assert not report["is_check_passed"]
    assert expected_fragment in report["message"]
252+
253+
233254
def test_check_jsonl_invalid_value_type(tmp_path: Path):
234255
# Create a JSONL file with an invalid value type
235256
file = tmp_path / "invalid_value_type.jsonl"
@@ -257,7 +278,7 @@ def test_check_jsonl_missing_field_in_conversation(tmp_path: Path):
257278

258279
report = check_file(file)
259280
assert not report["is_check_passed"]
260-
assert "Field `content` is missing for a turn" in report["message"]
281+
assert "Missing required column `content`" in report["message"]
261282

262283

263284
def test_check_jsonl_wrong_turn_type(tmp_path: Path):
@@ -277,7 +298,7 @@ def test_check_jsonl_wrong_turn_type(tmp_path: Path):
277298
report = check_file(file)
278299
assert not report["is_check_passed"]
279300
assert (
280-
"Invalid format on line 1 of the input file. Expected a dictionary"
301+
"Invalid format on line 1 of the input file. The `messages` column must be a list of dicts."
281302
in report["message"]
282303
)
283304

@@ -301,9 +322,7 @@ def test_check_jsonl_empty_messages(tmp_path: Path):
301322

302323
report = check_file(file)
303324
assert not report["is_check_passed"]
304-
assert (
305-
"Expected a non-empty list of messages. Found empty list" in report["message"]
306-
)
325+
assert "The `messages` column must not be empty" in report["message"]
307326

308327

309328
def test_check_jsonl_valid_weights_all_messages(tmp_path: Path):

0 commit comments

Comments
 (0)