Skip to content

Commit 6ed04bb

Browse files
authored
integrations: fastai speed up tests (#301)
1 parent 56b3096 commit 6ed04bb

File tree

1 file changed

+16
-20
lines changed

1 file changed

+16
-20
lines changed

tests/test_fastai.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,10 @@
33
import pytest
44
from fastai.tabular.all import (
55
Categorify,
6-
FillMissing,
76
Normalize,
87
TabularDataLoaders,
9-
URLs,
108
accuracy,
119
tabular_learner,
12-
untar_data,
1310
)
1411

1512
from dvclive.data.scalar import Scalar
@@ -20,24 +17,23 @@
2017

2118
@pytest.fixture
2219
def data_loader():
23-
path = untar_data(URLs.ADULT_SAMPLE)
24-
25-
dls = TabularDataLoaders.from_csv(
26-
path / "adult.csv",
27-
path=path,
28-
y_names="salary",
29-
cat_names=[
30-
"workclass",
31-
"education",
32-
"marital-status",
33-
"occupation",
34-
"relationship",
35-
"race",
36-
],
37-
cont_names=["age", "fnlwgt", "education-num"],
38-
procs=[Categorify, FillMissing, Normalize],
20+
from pandas import DataFrame
21+
22+
d = {
23+
"x1": [1, 1, 0, 0, 1, 1, 0, 0],
24+
"x2": [1, 0, 1, 0, 1, 0, 1, 0],
25+
"y": [1, 0, 0, 1, 1, 0, 0, 1],
26+
}
27+
df = DataFrame(d)
28+
xor_loader = TabularDataLoaders.from_df(
29+
df,
30+
valid_idx=[4, 5, 6, 7],
31+
batch_size=2,
32+
cont_names=["x1", "x2"],
33+
procs=[Categorify, Normalize],
34+
y_names="y",
3935
)
40-
return dls
36+
return xor_loader
4137

4238

4339
def test_fastai_callback(tmp_dir, data_loader):

0 commit comments

Comments
 (0)