-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbasic_steps_two.py
More file actions
122 lines (93 loc) · 4.28 KB
/
basic_steps_two.py
File metadata and controls
122 lines (93 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import random
import numpy as np
import logging
import datetime
import time
import os
import torch
import torch.optim as optim
import my_utils
import application.model as model
# Seed torch's default random number generator so that repeated runs of the
# experiment are reproducible.
# NOTE(review): only torch is seeded here; the `random` and `numpy` modules
# imported at the top of the file are not -- seed them too if anything
# downstream draws from them.
torch.manual_seed(42)

# Device that every model part and tensor is placed on. Use 'cuda' for an
# NVIDIA GPU, or 'mps' on an Apple-silicon Mac.
device_type = 'cpu'
device = torch.device(device_type)
# --- START initialize the dataset -- for this we can use the Flower framework
# Describes how the dataset is split across clients. This toy example has a
# single client, so the partitioning strategy has no practical effect here.
partition_args = dict(
    # Any dataset id listed on https://huggingface.co/datasets can be used.
    dataset_name='ylecun/mnist',
    num_clients=1,
    partition_method='IID',  # TODO
    partition_target='label',
    alpha=None,  # only relevant for Dirichlet partitioning
    shard_size=10,
    num_shards_per_partition=10,
)
# Per-client train/validation loaders plus a single, shared test loader.
# NOTE(review): df_list presumably holds the per-partition data used for the
# plot -- confirm against my_utils.execute_partition_and_plot.
trainloaders, valloaders, testloader, df_list = my_utils.execute_partition_and_plot(
    partition_args
)
# --- END initialize the dataset -- for this we can use the Flower framework
# --- START initializing the model
# The same SimpleCNNMNIST architecture is instantiated three times and sliced
# at two cut points, yielding three consecutive parts of one network:
#   client part A : first_cut=-1               .. last_cut=cut_layer_first
#   server part B : first_cut=cut_layer_first  .. last_cut=cut_layer_second
#   client part C : first_cut=cut_layer_second .. last_cut=-1
# NOTE(review): -1 presumably means "no cut on that side" (start / end of the
# network), and the exact layer-index convention of the cuts should be
# confirmed in application/model.py.
input_dim = (16 * 4 * 4)  # presumably the flattened conv feature size (16 x 4 x 4) -- confirm
hidden_dims = [120, 84]
cut_layer_first = 2
cut_layer_second = 3
shared_model_kwargs = dict(input_dim=input_dim, hidden_dims=hidden_dims, output_dim=10)
client_model_a = model.SimpleCNNMNIST(**shared_model_kwargs, first_cut=-1, last_cut=cut_layer_first)
server_model = model.SimpleCNNMNIST(**shared_model_kwargs, first_cut=cut_layer_first, last_cut=cut_layer_second)
client_model_c = model.SimpleCNNMNIST(**shared_model_kwargs, first_cut=cut_layer_second, last_cut=-1)
# The application folder shows a faster way to define the models for all
# entities, using the get_model functions.
# ---- END initializing the model
# --- START training: one epoch of three-party split learning.
# Per batch, activations flow forward client part A -> server part B ->
# client part C (which computes the loss). Activation-gradients then flow
# back C -> B -> A. At each cut the tensor is detached and cloned, so the
# parties exchange only tensor values, never autograd graphs or parameters.
# Each party updates only its own parameters with its own optimizer.

# Move every part to the target device before creating its optimizer, as the
# torch.optim documentation recommends.
client_model_a.to(device)
server_model.to(device)
client_model_c.to(device)

# initialize optimizers -- one per party, each over that party's parameters only
client_optimizer_a = optim.SGD(client_model_a.parameters(), lr=0.05, weight_decay=0)
server_optimizer = optim.SGD(server_model.parameters(), lr=0.05, weight_decay=0)
client_optimizer_c = optim.SGD(client_model_c.parameters(), lr=0.05, weight_decay=0)

# let's start training
client_model_a.train()
server_model.train()
client_model_c.train()

criterion = torch.nn.CrossEntropyLoss()
key_ = 'image'    # batch key holding the model inputs
label_ = 'label'  # batch key holding the target class labels

train_loader = trainloaders[0]  # training loader of the single client
epoch_loss = 0
batch_iter = 0
# Batches are counted while iterating. len(trainloaders) counts clients
# (1 here), not batches, so it must not be used as the loop bound or as the
# denominator of the average loss.
for batch_iter, batch in enumerate(train_loader, start=1):
    inputs = batch[key_].to(device)
    labels = batch[label_].to(device)

    # Clear the gradients each party accumulated in the previous step.
    client_optimizer_a.zero_grad()
    server_optimizer.zero_grad()
    client_optimizer_c.zero_grad()

    # client side -- forward propagation model part-a
    my_outa = client_model_a(inputs)
    # Intermediate activations sent client -> server. detach() cuts the
    # autograd graph; requires_grad_ makes this a leaf whose .grad receives
    # the gradient w.r.t. these activations during the server's backward pass.
    det_out_a = my_outa.detach().clone().requires_grad_(True)

    # server side -- forward propagation part-b
    my_outb = server_model(det_out_a)
    # Intermediate activations sent server -> client (same mechanism as above).
    det_out_b = my_outb.detach().clone().requires_grad_(True)

    # client side -- forward propagation model part c, then the loss
    out = client_model_c(det_out_b)
    loss = criterion(out, labels)
    epoch_loss += loss.item()

    # client side -- backward propagation part-c
    loss.backward()
    client_optimizer_c.step()
    grad_b = det_out_b.grad.clone().detach()  # intermediate gradients sent client -> server

    # server side -- backward propagation part-b
    my_outb.backward(grad_b)  # continue backprop from the gradients received from the client
    server_optimizer.step()
    grad_a = det_out_a.grad.clone().detach()  # intermediate gradients sent server -> client

    # client side -- backward propagation part-a
    my_outa.backward(grad_a)  # continue backprop from the gradients received from the server
    client_optimizer_a.step()

total_batch = batch_iter
if total_batch == 0:
    raise RuntimeError('The training loader yielded no batches.')
# epoch_loss sums the per-batch losses (CrossEntropyLoss uses reduction='mean'
# by default), so this is the mean loss per batch over the epoch.
print(f'The average loss of the epoch is {epoch_loss/total_batch}')