|
3 | 3 | import pytest |
4 | 4 | from fastai.tabular.all import ( |
5 | 5 | Categorify, |
6 | | - FillMissing, |
7 | 6 | Normalize, |
8 | 7 | TabularDataLoaders, |
9 | | - URLs, |
10 | 8 | accuracy, |
11 | 9 | tabular_learner, |
12 | | - untar_data, |
13 | 10 | ) |
14 | 11 |
|
15 | 12 | from dvclive.data.scalar import Scalar |
|
20 | 17 |
|
21 | 18 | @pytest.fixture |
22 | 19 | def data_loader(): |
23 | | - path = untar_data(URLs.ADULT_SAMPLE) |
24 | | - |
25 | | - dls = TabularDataLoaders.from_csv( |
26 | | - path / "adult.csv", |
27 | | - path=path, |
28 | | - y_names="salary", |
29 | | - cat_names=[ |
30 | | - "workclass", |
31 | | - "education", |
32 | | - "marital-status", |
33 | | - "occupation", |
34 | | - "relationship", |
35 | | - "race", |
36 | | - ], |
37 | | - cont_names=["age", "fnlwgt", "education-num"], |
38 | | - procs=[Categorify, FillMissing, Normalize], |
| 20 | + from pandas import DataFrame |
| 21 | + |
| 22 | + d = { |
| 23 | + "x1": [1, 1, 0, 0, 1, 1, 0, 0], |
| 24 | + "x2": [1, 0, 1, 0, 1, 0, 1, 0], |
| 25 | + "y": [1, 0, 0, 1, 1, 0, 0, 1], |
| 26 | + } |
| 27 | + df = DataFrame(d) |
| 28 | + xor_loader = TabularDataLoaders.from_df( |
| 29 | + df, |
| 30 | + valid_idx=[4, 5, 6, 7], |
| 31 | + batch_size=2, |
| 32 | + cont_names=["x1", "x2"], |
| 33 | + procs=[Categorify, Normalize], |
| 34 | + y_names="y", |
39 | 35 | ) |
40 | | - return dls |
| 36 | + return xor_loader |
41 | 37 |
|
42 | 38 |
|
43 | 39 | def test_fastai_callback(tmp_dir, data_loader): |
|
0 commit comments