@@ -54,7 +54,7 @@ def forward(self, x):
5454
5555@pytest .fixture
5656def model ():
57- model = DummyModel ()
57+ model = DummyModel (out_channels = 10 )
5858 yield model
5959
6060
@@ -305,7 +305,7 @@ def test_detach_terminates(lr_finder, to_save, dummy_engine, dataloader, recwarn
305305
306306 dummy_engine .run (dataloader , max_epochs = 3 )
307307 assert dummy_engine .state .epoch == 3
308- assert len (recwarn ) == 1
308+ assert len (recwarn ) == 0
309309
310310
311311def test_lr_suggestion_unexpected_curve (lr_finder , to_save , dummy_engine , dataloader ):
@@ -320,10 +320,11 @@ def test_lr_suggestion_unexpected_curve(lr_finder, to_save, dummy_engine, datalo
320320
321321
322322def test_lr_suggestion_single_param_group (lr_finder ): # , to_save, dummy_engine, dataloader):
323+ import numpy as np
323324
324325 noise = 0.05
325- lr_finder ._history ["loss" ] = torch .linspace (- 5.0 , 5.0 , steps = 100 ) ** 2 + noise
326- lr_finder ._history ["lr" ] = torch .linspace (0.01 , 10 , steps = 100 )
326+ lr_finder ._history ["loss" ] = np .linspace (- 5.0 , 5.0 , num = 100 ) ** 2 + noise
327+ lr_finder ._history ["lr" ] = np .linspace (0.01 , 10 , num = 100 )
327328
328329 # lr_finder.lr_suggestion() is supposed to return a value, but as
329330 # we assign loss and lr to tensors, instead of lists, it will return tensors
@@ -336,9 +337,9 @@ def test_lr_suggestion_multiple_param_groups(lr_finder):
336337 import numpy as np
337338
338339 noise = 0.06
339- lr_finder ._history ["loss" ] = torch . tensor ( np .linspace (- 5.0 , 5 , num = 50 ) ** 2 + noise )
340+ lr_finder ._history ["loss" ] = np .linspace (- 5.0 , 5 , num = 50 ) ** 2 + noise
340341 # 2 param_groups
341- lr_finder ._history ["lr" ] = torch . tensor ( np .linspace (0.01 , 10 , num = 100 ) ).reshape (50 , 2 )
342+ lr_finder ._history ["lr" ] = np .linspace (0.01 , 10 , num = 100 ).reshape (50 , 2 )
342343
343344 # lr_finder.lr_suggestion() is supposed to return a list of values,
344345 # but as we assign loss and lr to tensors, instead of lists, it will return tensors
@@ -352,7 +353,7 @@ def test_lr_suggestion_mnist(lr_finder, mnist_to_save, dummy_engine_mnist, mnist
352353
353354 max_iters = 50
354355
355- with lr_finder .attach (dummy_engine_mnist , mnist_to_save ) as trainer_with_finder :
356+ with lr_finder .attach (dummy_engine_mnist , mnist_to_save , diverge_th = 2 , step_mode = "linear" ) as trainer_with_finder :
356357
357358 with trainer_with_finder .add_event_handler (
358359 Events .ITERATION_COMPLETED (once = max_iters ), lambda _ : trainer_with_finder .terminate ()
@@ -393,11 +394,11 @@ def test_apply_suggested_lr_multiple_param_groups(
393394 to_save_mulitple_param_groups ,
394395 dummy_engine_mulitple_param_groups ,
395396 optimizer_multiple_param_groups ,
396- dataloader ,
397+ dataloader_plot ,
397398):
398399
399400 with lr_finder .attach (dummy_engine_mulitple_param_groups , to_save_mulitple_param_groups ) as trainer_with_finder :
400- trainer_with_finder .run (dataloader )
401+ trainer_with_finder .run (dataloader_plot )
401402
402403 sug_lr = lr_finder .lr_suggestion ()
403404 lr_finder .apply_suggested_lr (optimizer_multiple_param_groups )
0 commit comments