-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathtrain.py
More file actions
78 lines (65 loc) · 2.01 KB
/
train.py
File metadata and controls
78 lines (65 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from torch_geometric.datasets import ZINC
from tqdm import tqdm
import torch
from torch import nn
import math
import wandb
import os
from models import DiffusionOrderingNetwork, DenoisingNetwork
from utils import NodeMasking
from grapharm import GraphARM
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
_accelerator = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_accelerator)
print(f"Using device {device}")
# Instantiate the ZINC dataset (stored under ./data/ZINC), with no transforms.
dataset = ZINC(root='./data/ZINC', transform=None, pre_transform=None)

# Vocabulary sizes shared by both networks: the number of distinct values in
# the node-feature tensor and in the edge-attribute tensor of the dataset.
num_node_types = dataset.x.unique().shape[0]
num_edge_types = dataset.edge_attr.unique().shape[0]

# Diffusion-ordering network (3 layers, one output channel per node).
diff_ord_net = DiffusionOrderingNetwork(
    node_feature_dim=1,
    num_node_types=num_node_types,
    num_edge_types=num_edge_types,
    num_layers=3,
    out_channels=1,
    device=device,
)

# NOTE(review): `masker` is not referenced anywhere else in this script.
masker = NodeMasking(dataset)

# Denoising network (7 layers); hidden_dim is left at the class default.
denoising_net = DenoisingNetwork(
    node_feature_dim=dataset.num_features,
    edge_feature_dim=dataset.num_edge_features,
    num_node_types=num_node_types,
    num_edge_types=num_edge_types,
    num_layers=7,
    device=device,
)
# Hyperparameters actually used by the training loop below. They are also
# logged to W&B, so the recorded run config matches what was executed.
n_epochs = 2000
batch_size = 5
subset_size = 5  # number of ZINC graphs kept for training
train_step_m = 4  # value passed to GraphARM.train_step as M

wandb.init(
    project="GraphARM",
    group="v2.3.1",
    name="ZINC_GraphARM",
    config={
        "policy": "train",
        "n_epochs": n_epochs,
        "batch_size": batch_size,
        # NOTE(review): lr is not passed to GraphARM in this script; confirm it
        # matches the optimizer GraphARM sets up internally.
        "lr": 1e-3,
        "subset_size": subset_size,
        "M": train_step_m,
    },
    # Pass mode='disabled' to run without logging to W&B.
)

# Anomaly detection helps locate the op that produced NaN/inf gradients, but it
# adds overhead to every backward pass; consider disabling it for long runs.
torch.autograd.set_detect_anomaly(True)

# GraphARM is built from the full dataset; only training below uses the subset.
grapharm = GraphARM(
    dataset=dataset,
    denoising_network=denoising_net,
    diffusion_ordering_network=diff_ord_net,
    device=device,
)

dataset = dataset[0:subset_size]
num_graphs = len(dataset)

# Resume from a saved model when possible; otherwise start fresh. Catch
# ordinary exceptions only (so Ctrl-C still interrupts) and report the reason.
try:
    grapharm.load_model()
except Exception as err:
    print(f"No model to load ({type(err).__name__}: {err})")
else:
    print("Loaded model")

# train loop
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    # Each epoch consumes two consecutive chunks of `batch_size` graphs:
    # first the train batch, then the val batch. Indices wrap modulo the
    # subset length. Plain slices would run past the end of a small subset
    # and hand train_step empty batches. When subset_size == batch_size,
    # train and val contain the same graphs every epoch.
    train_start = 2 * epoch * batch_size
    val_start = train_start + batch_size
    train_idx = [(train_start + k) % num_graphs for k in range(batch_size)]
    val_idx = [(val_start + k) % num_graphs for k in range(batch_size)]
    grapharm.train_step(
        train_batch=dataset[train_idx],
        val_batch=dataset[val_idx],
        M=train_step_m,
    )
    grapharm.save_model()