diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 4120f0926f0f..740ba073f329 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2718,7 +2718,7 @@ def test_beam_search_early_stop_heuristic(self):
         question = tokenizer.apply_chat_template(
             question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
         )
-        inputs = tokenizer(question, return_tensors="pt", padding=True).to("cuda")
+        inputs = tokenizer(question, return_tensors="pt", padding=True).to(torch_device)
         outputs = model.generate(**inputs, generation_config=generation_config)
         responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
         self.assertEqual(responses[0], EXPECTED_OUTPUT)
@@ -2737,7 +2737,7 @@ def test_beam_search_early_stop_heuristic(self):
         cot_question = tokenizer.apply_chat_template(
             cot_question, tokenize=False, add_generation_prompt=True, return_tensors="pt"
         )
-        inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to("cuda")
+        inputs = tokenizer([question, cot_question], return_tensors="pt", padding=True).to(torch_device)
         outputs = model.generate(**inputs, generation_config=generation_config)
         responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)