|
11 | 11 | "example_1": "[INST] Who made Berlin [/INST] dunno", |
12 | 12 | "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!", |
13 | 13 | }, |
14 | | - "meta-llama/Meta-Llama-3.1-8B":{ |
| 14 | + "meta-llama/Meta-Llama-3.1-8B-Instruct":{ |
15 | 15 | "example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>", |
16 | 16 | "example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?", |
17 | 17 | }, |
@@ -114,3 +114,30 @@ def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, |
114 | 114 | } |
115 | 115 | with pytest.raises(AttributeError): |
116 | 116 | main(**kwargs) |
| 117 | + |
| 118 | +@pytest.mark.skip_missing_tokenizer |
| 119 | +@patch('llama_recipes.finetuning.AutoTokenizer') |
| 120 | +def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version): |
| 121 | + monkeypatch.syspath_prepend("recipes/quickstart/finetuning/datasets/") |
| 122 | + from custom_dataset import tokenize_dialog |
| 123 | + |
| 124 | + setup_tokenizer(tokenizer) |
| 125 | + tokenizer = tokenizer.from_pretrained() |
| 126 | + |
| 127 | + dialog = [ |
| 128 | + {"role":"user", "content":"Who made Berlin?"}, |
| 129 | + {"role":"assistant", "content":"dunno"}, |
| 130 | + {"role":"user", "content":"And Rome?"}, |
| 131 | + {"role":"assistant", "content":"Romans"}, |
| 132 | + ] |
| 133 | + |
| 134 | + result = tokenize_dialog(dialog, tokenizer) |
| 135 | + |
| 136 | + if "Llama-2" in llama_version: |
| 137 | + assert result["labels"][:12] == [-100] * 12 |
| 138 | + assert result["labels"][17:28] == [-100] * 11 |
| 139 | + assert result["labels"].count(-100) == 11 + 12 |
| 140 | + else: |
| 141 | + assert result["labels"][:38] == [-100] * 38 |
| 142 | + assert result["labels"][43:54] == [-100] * 11 |
| 143 | + assert result["labels"].count(-100) == 38 + 11 |
0 commit comments