-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathnull_task_distribution.py
More file actions
97 lines (80 loc) · 2.52 KB
/
null_task_distribution.py
File metadata and controls
97 lines (80 loc) · 2.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torch.nn.functional as F
import numpy as np
import os
class Net(torch.nn.Module):
    """Three-layer fully connected network over a flattened S x S board.

    Every layer maps ``S*S`` features to ``S*S`` features. The output is passed
    through a sigmoid, so each entry lies between 0 and 1 and can be read as a
    per-cell probability.

    Args:
        S: Side length of the square board.
    """

    def __init__(self, S=7):
        super().__init__()
        n_cells = S * S
        # The attribute names fc1/fc2/fc3 become the state_dict keys, so they
        # must stay as they are for saved checkpoints to keep loading.
        self.fc1 = torch.nn.Linear(n_cells, n_cells)
        self.fc2 = torch.nn.Linear(n_cells, n_cells)
        self.fc3 = torch.nn.Linear(n_cells, n_cells)

    def forward(self, x):
        """Return per-cell sigmoid activations for ``x``.

        Args:
            x: Float tensor whose last dimension has ``S*S`` entries.

        Returns:
            Tensor with the same shape as ``x`` and values between 0 and 1.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # torch.sigmoid replaces the legacy F.sigmoid alias, which emits a
        # deprecation warning on older PyTorch releases.
        return torch.sigmoid(self.fc3(x))
# Module-level network instance; the sampling functions in this file use it as
# their default ``network`` argument.
network = Net()
# Load trained weights when a checkpoint exists in the current working
# directory; otherwise the network keeps its random initialisation.
if os.path.isfile("task_generator.pt"):
    # map_location="cpu" lets a checkpoint that was saved from a GPU load on a
    # CPU-only machine. The samplers call .numpy() on the network output, so
    # the weights have to live on the CPU anyway.
    network.load_state_dict(torch.load("task_generator.pt", map_location="cpu"))
    network.eval()
#Sample a single board from the null task distribution.
def gibbs_sample(S=7, numSweeps=20, network=network):
    """Draw one board by Gibbs sampling, using ``network`` as the conditional model.

    Starts from a uniformly random 0/1 board. In each sweep every cell is
    visited once in a freshly shuffled order: the cell is replaced by -1 in a
    copy of the board, the network is evaluated on that copy, and the cell is
    set to 1 with the probability the network outputs for it (0 otherwise).

    Args:
        S: Side length of the board; the board has ``S*S`` cells.
        numSweeps: Number of full passes over all cells.
        network: Callable mapping a float tensor of length ``S*S`` to a tensor
            of per-cell probabilities. Defaults to the module-level
            ``network`` (bound when this function is defined).

    Returns:
        1-D NumPy integer array of length ``S*S`` containing only 0s and 1s.
    """
    n_cells = S * S
    # The size used to be hard-coded to 49, which only matched the default S=7.
    M = np.random.choice([0, 1], size=n_cells, replace=True)
    idxs = np.arange(n_cells)
    # Sampling never back-propagates, so skip autograd bookkeeping.
    with torch.no_grad():
        for _ in range(numSweeps):
            np.random.shuffle(idxs)
            for i in idxs:
                M_eval = M.copy()
                M_eval[i] = -1  # hide the cell that is about to be resampled
                masked = torch.from_numpy(M_eval).float()
                preds = network(masked).numpy()
                M[i] = 1 if np.random.rand() < preds[i] else 0
    return M
#Vectorized form of gibbs_sample that produces many boards at the same time.
def batch_gibbs(S=7, numSweeps=20, batch_size=500, network=network):
    """Draw ``batch_size`` boards in parallel by Gibbs sampling.

    All boards start from independent uniformly random 0/1 values. In each
    sweep the cells are visited in one shuffled order shared by the whole
    batch: column ``i`` is set to -1 in a copy of the batch, the network is
    evaluated on that copy, and column ``i`` of every board is resampled as 1
    with the predicted probability (0 otherwise).

    Args:
        S: Side length of each board; each board has ``S*S`` cells.
        numSweeps: Number of full passes over all cells.
        batch_size: Number of boards sampled at once.
        network: Callable mapping a ``(batch_size, S*S)`` float tensor to a
            tensor of per-cell probabilities with the same shape. Defaults to
            the module-level ``network`` (bound when this function is defined).

    Returns:
        float32 NumPy array of shape ``(batch_size, S*S)`` containing 0s and 1s.
    """
    n_cells = S * S
    init = np.random.choice([0, 1], size=n_cells * batch_size, replace=True)
    M = torch.from_numpy(init.reshape((batch_size, n_cells))).float()
    idxs = np.arange(n_cells)
    # Sampling never back-propagates; without no_grad every forward pass would
    # build an autograd graph that is thrown away immediately.
    with torch.no_grad():
        for _ in range(numSweeps):
            np.random.shuffle(idxs)
            for i in idxs:
                M_eval = M.clone()
                M_eval[:, i] = -1  # hide column i in every board
                preds = network(M_eval)[:, i]
                r = torch.rand(batch_size)
                # 1 where the uniform draw falls below the predicted
                # probability, 0 everywhere else.
                M[:, i] = (r < preds).float()
    return M.numpy()
#If called as main, trains the network and saves its weights.
if __name__ == '__main__':
    # Only generate_grid is needed from grid_env; importing it by name avoids
    # pulling unknown names into this module's namespace.
    from grid_env import generate_grid

    S = 7
    net = Net(S=S)
    # NOTE(review): lr is never passed to the optimizer, so Adam runs with its
    # default learning rate. Confirm which value was intended before wiring it in.
    lr = 0.0002
    # Training set: sample 100000 grids, flatten each one, and keep only the
    # distinct boards.
    configurations = np.asarray(
        list({tuple(generate_grid('all')[0].flatten()) for _ in range(100000)})
    )
    optimizer = torch.optim.Adam(net.parameters())
    criterion = torch.nn.BCELoss()
    true = configurations.copy()
    for epoch in range(10000):
        np.random.shuffle(true)
        # Hide num_changes randomly chosen cells in every board.
        num_changes = 1
        change_idxs = np.zeros(true.shape, dtype=bool)
        for i in range(true.shape[0]):
            idxs = np.random.choice(
                np.arange(true.shape[1]), size=num_changes, replace=False
            )
            change_idxs[i, idxs] = True
        data = true.copy()
        data[change_idxs] = -1
        masked = torch.from_numpy(data).float()
        labels = torch.from_numpy(true).float()
        tensor_idxs = torch.from_numpy(change_idxs)
        preds = net(masked)
        # The loss is evaluated only on the hidden cells.
        loss = criterion(preds[tensor_idxs], labels[tensor_idxs])
        y_hat = (preds[tensor_idxs] > 0.5).int().numpy()
        y = labels[tensor_idxs].int().numpy()
        # Log the epoch, the scalar loss, and the accuracy on the hidden cells.
        print(epoch, loss.item(), np.mean(y_hat == y))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    torch.save(net.state_dict(), "task_generator.pt")