Skip to content

Commit b8a8ecf

Browse files
Dave Berenbaum and shcheklein authored
Adds lightning fabric integration (#749)
* wip lightning fabric logger * add fabric tests * lightning: fix test_lightning_val_updates_to_studio * lightning: move next_step logic to fabric * add sync method and use in lightning/fabric callbacks * fabric: auto-increment step * add fabric example notebook * update fabric notebook * fix mypy errors * skip fabric tests if not installed * fix(project): proper lighting extra name reference * fix(mypy): remove options that is default now * add back nox * fix linting issues * make torch import optional --------- Co-authored-by: Ivan Shcheklein <[email protected]>
1 parent a7b7f55 commit b8a8ecf

File tree

8 files changed

+595
-94
lines changed

8 files changed

+595
-94
lines changed

examples/DVCLive-Fabric.ipynb

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"id": "QKSE19fW_Dnj"
7+
},
8+
"source": [
9+
"# DVCLive and Lightning Fabric"
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {
15+
"id": "q-C_4R_o_QGG"
16+
},
17+
"source": [
18+
"## Install dvclive"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"metadata": {
25+
"colab": {
26+
"base_uri": "https://localhost:8080/"
27+
},
28+
"id": "-XFbvwq7TSwN",
29+
"outputId": "15d0e3b5-bb4a-4b3e-d37f-21608d1822ed"
30+
},
31+
"outputs": [],
32+
"source": [
33+
"!pip install \"dvclive[lightning]\""
34+
]
35+
},
36+
{
37+
"cell_type": "markdown",
38+
"metadata": {
39+
"id": "I6S6Uru1_Y0x"
40+
},
41+
"source": [
42+
"## Initialize DVC Repository"
43+
]
44+
},
45+
{
46+
"cell_type": "code",
47+
"execution_count": null,
48+
"metadata": {
49+
"colab": {
50+
"base_uri": "https://localhost:8080/"
51+
},
52+
"id": "WcbvUl2uTV0y",
53+
"outputId": "aff9740c-26db-483d-ce30-cfef395f3cbb"
54+
},
55+
"outputs": [],
56+
"source": [
57+
"!git init -q\n",
58+
"!git config --local user.email \"you@example.com\"\n",
59+
"!git config --local user.name \"Your Name\"\n",
60+
"!dvc init -q\n",
61+
"!git commit -m \"DVC init\""
62+
]
63+
},
64+
{
65+
"cell_type": "markdown",
66+
"metadata": {
67+
"id": "LmY4PLMh_cUk"
68+
},
69+
"source": [
70+
"## Imports"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": null,
76+
"metadata": {
77+
"id": "85qErT5yTEbN"
78+
},
79+
"outputs": [],
80+
"source": [
81+
"import argparse\n",
82+
"from os import path\n",
83+
"from types import SimpleNamespace\n",
84+
"\n",
85+
"import torch\n",
86+
"import torch.nn as nn\n",
87+
"import torch.nn.functional as F\n",
88+
"import torch.optim as optim\n",
89+
"import torchvision.transforms as T\n",
90+
"from lightning.fabric import Fabric, seed_everything\n",
91+
"from lightning.fabric.utilities.rank_zero import rank_zero_only\n",
92+
"from torch.optim.lr_scheduler import StepLR\n",
93+
"from torchmetrics.classification import Accuracy\n",
94+
"from torchvision.datasets import MNIST\n",
95+
"\n",
96+
"from dvclive.fabric import DVCLiveLogger\n",
97+
"\n",
98+
"DATASETS_PATH = (\"Datasets\")"
99+
]
100+
},
101+
{
102+
"cell_type": "markdown",
103+
"metadata": {
104+
"id": "UrmAHbhr_lgs"
105+
},
106+
"source": [
107+
"## Setup model code\n",
108+
"\n",
109+
"Adapted from https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/fabric/image_classifier/train_fabric.py.\n",
110+
"\n",
111+
"Look for the `logger` statements where DVCLiveLogger calls were added."
112+
]
113+
},
114+
{
115+
"cell_type": "code",
116+
"execution_count": null,
117+
"metadata": {
118+
"id": "UCzTygUnTHM8"
119+
},
120+
"outputs": [],
121+
"source": [
122+
"class Net(nn.Module):\n",
123+
" def __init__(self) -> None:\n",
124+
" super().__init__()\n",
125+
" self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
126+
" self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
127+
" self.dropout1 = nn.Dropout(0.25)\n",
128+
" self.dropout2 = nn.Dropout(0.5)\n",
129+
" self.fc1 = nn.Linear(9216, 128)\n",
130+
" self.fc2 = nn.Linear(128, 10)\n",
131+
"\n",
132+
" def forward(self, x):\n",
133+
" x = self.conv1(x)\n",
134+
" x = F.relu(x)\n",
135+
" x = self.conv2(x)\n",
136+
" x = F.relu(x)\n",
137+
" x = F.max_pool2d(x, 2)\n",
138+
" x = self.dropout1(x)\n",
139+
" x = torch.flatten(x, 1)\n",
140+
" x = self.fc1(x)\n",
141+
" x = F.relu(x)\n",
142+
" x = self.dropout2(x)\n",
143+
" x = self.fc2(x)\n",
144+
" return F.log_softmax(x, dim=1)\n",
145+
"\n",
146+
"\n",
147+
"def run(hparams):\n",
148+
" # Create the DVCLive Logger\n",
149+
" logger = DVCLiveLogger(report=\"notebook\")\n",
150+
"\n",
151+
" # Log dict of hyperparameters\n",
152+
" logger.log_hyperparams(hparams.__dict__)\n",
153+
"\n",
154+
" # Create the Lightning Fabric object. The parameters like accelerator, strategy, devices etc. will be provided\n",
155+
" # by the command line. See all options: `lightning run model --help`\n",
156+
" fabric = Fabric()\n",
157+
"\n",
158+
" seed_everything(hparams.seed) # instead of torch.manual_seed(...)\n",
159+
"\n",
160+
" transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])\n",
161+
"\n",
162+
" # Let rank 0 download the data first, then everyone will load MNIST\n",
163+
" with fabric.rank_zero_first(local=False): # set `local=True` if your filesystem is not shared between machines\n",
164+
" train_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=True, transform=transform)\n",
165+
" test_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=False, transform=transform)\n",
166+
"\n",
167+
" train_loader = torch.utils.data.DataLoader(\n",
168+
" train_dataset,\n",
169+
" batch_size=hparams.batch_size,\n",
170+
" )\n",
171+
" test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams.batch_size)\n",
172+
"\n",
173+
" # don't forget to call `setup_dataloaders` to prepare the dataloaders for distributed training.\n",
174+
" train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)\n",
175+
"\n",
176+
" model = Net() # remove call to .to(device)\n",
177+
" optimizer = optim.Adadelta(model.parameters(), lr=hparams.lr)\n",
178+
"\n",
179+
" # don't forget to call `setup` to prepare the model / optimizer for distributed training.\n",
180+
" # the model is moved automatically to the right device.\n",
181+
" model, optimizer = fabric.setup(model, optimizer)\n",
182+
"\n",
183+
" scheduler = StepLR(optimizer, step_size=1, gamma=hparams.gamma)\n",
184+
"\n",
185+
" # use torchmetrics instead of manually computing the accuracy\n",
186+
" test_acc = Accuracy(task=\"multiclass\", num_classes=10).to(fabric.device)\n",
187+
"\n",
188+
" # EPOCH LOOP\n",
189+
" for epoch in range(1, hparams.epochs + 1):\n",
190+
" # TRAINING LOOP\n",
191+
" model.train()\n",
192+
" for batch_idx, (data, target) in enumerate(train_loader):\n",
193+
" # NOTE: no need to call `.to(device)` on the data, target\n",
194+
" optimizer.zero_grad()\n",
195+
" output = model(data)\n",
196+
" loss = F.nll_loss(output, target)\n",
197+
" fabric.backward(loss) # instead of loss.backward()\n",
198+
"\n",
199+
" optimizer.step()\n",
200+
" if (batch_idx == 0) or ((batch_idx + 1) % hparams.log_interval == 0):\n",
201+
" print(\n",
202+
" \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n",
203+
" epoch,\n",
204+
" batch_idx * len(data),\n",
205+
" len(train_loader.dataset),\n",
206+
" 100.0 * batch_idx / len(train_loader),\n",
207+
" loss.item(),\n",
208+
" )\n",
209+
" )\n",
210+
"\n",
211+
" # Log dict of metrics\n",
212+
" logger.log_metrics({\"loss\": loss.item()})\n",
213+
"\n",
214+
" if hparams.dry_run:\n",
215+
" break\n",
216+
"\n",
217+
" scheduler.step()\n",
218+
"\n",
219+
" # TESTING LOOP\n",
220+
" model.eval()\n",
221+
" test_loss = 0\n",
222+
" with torch.no_grad():\n",
223+
" for data, target in test_loader:\n",
224+
" # NOTE: no need to call `.to(device)` on the data, target\n",
225+
" output = model(data)\n",
226+
" test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n",
227+
"\n",
228+
" # WITHOUT TorchMetrics\n",
229+
" # pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n",
230+
" # correct += pred.eq(target.view_as(pred)).sum().item()\n",
231+
"\n",
232+
" # WITH TorchMetrics\n",
233+
" test_acc(output, target)\n",
234+
"\n",
235+
" if hparams.dry_run:\n",
236+
" break\n",
237+
"\n",
238+
" # all_gather is used to aggregate the value across processes\n",
239+
" test_loss = fabric.all_gather(test_loss).sum() / len(test_loader.dataset)\n",
240+
"\n",
241+
" print(f\"\\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({100 * test_acc.compute():.0f}%)\\n\")\n",
242+
"\n",
243+
" # log additional metrics\n",
244+
" logger.log_metrics({\"test_loss\": test_loss, \"test_acc\": 100 * test_acc.compute()})\n",
245+
"\n",
246+
" test_acc.reset()\n",
247+
"\n",
248+
" if hparams.dry_run:\n",
249+
" break\n",
250+
"\n",
251+
" # When using distributed training, use `fabric.save`\n",
252+
" # to ensure the current process is allowed to save a checkpoint\n",
253+
" if hparams.save_model:\n",
254+
" fabric.save(\"mnist_cnn.pt\", model.state_dict())\n",
255+
"\n",
256+
" # `logger.experiment` provides access to the `dvclive.Live` instance where you can use additional logging methods.\n",
257+
" # Check that `rank_zero_only.rank == 0` to avoid logging in other processes.\n",
258+
" if rank_zero_only.rank == 0:\n",
259+
" logger.experiment.log_artifact(\"mnist_cnn.pt\")\n",
260+
"\n",
261+
" # Call finalize to save final results as a DVC experiment\n",
262+
" logger.finalize(\"success\")"
263+
]
264+
},
265+
{
266+
"cell_type": "markdown",
267+
"metadata": {
268+
"id": "o5_v9lRDAM7l"
269+
},
270+
"source": [
271+
"## Train the model"
272+
]
273+
},
274+
{
275+
"cell_type": "code",
276+
"execution_count": null,
277+
"metadata": {
278+
"colab": {
279+
"base_uri": "https://localhost:8080/",
280+
"height": 1000
281+
},
282+
"id": "BbCXen1PTM4V",
283+
"outputId": "b79c90eb-74cc-474d-c0dd-21245064bca8"
284+
},
285+
"outputs": [],
286+
"source": [
287+
"hparams = SimpleNamespace(batch_size=64, epochs=5, lr=1.0, gamma=0.7, dry_run=False, seed=1, log_interval=10, save_model=True)\n",
288+
"run(hparams)"
289+
]
290+
},
291+
{
292+
"cell_type": "code",
293+
"execution_count": null,
294+
"metadata": {
295+
"id": "DnqCrlbLAopV"
296+
},
297+
"outputs": [],
298+
"source": []
299+
}
300+
],
301+
"metadata": {
302+
"colab": {
303+
"provenance": []
304+
},
305+
"kernelspec": {
306+
"display_name": "Python 3",
307+
"name": "python3"
308+
},
309+
"language_info": {
310+
"name": "python"
311+
}
312+
},
313+
"nbformat": 4,
314+
"nbformat_minor": 0
315+
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ fastai = ["fastai"]
6868
lightning = ["lightning>=2.0", "torch"]
6969
optuna = ["optuna"]
7070
all = [
71-
"dvclive[image,mmcv,tf,xgb,lgbm,huggingface,catalyst,fastai,pytorch-lightning,optuna,plots,markdown]"
71+
"dvclive[image,mmcv,tf,xgb,lgbm,huggingface,catalyst,fastai,lightning,optuna,plots,markdown]"
7272
]
7373

7474
[project.urls]

0 commit comments

Comments
 (0)