-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathrsgd.py
More file actions
137 lines (125 loc) · 5.06 KB
/
rsgd.py
File metadata and controls
137 lines (125 loc) · 5.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch.optim.optimizer
from geoopt import ManifoldParameter, ManifoldTensor
from .mixin import OptimMixin
# Public names exported by ``from <this module> import *``.
__all__ = ["RiemannianSGD"]
class RiemannianSGD(OptimMixin, torch.optim.Optimizer):
    r"""
    Riemannian Stochastic Gradient Descent with the same API as :class:`torch.optim.SGD`.

    Parameters
    ----------
    params : iterable
        iterable of parameters to optimize or dicts defining
        parameter groups
    lr : float
        learning rate
    momentum : float (optional)
        momentum factor (default: 0)
    weight_decay : float (optional)
        weight decay (L2 penalty) (default: 0)
    dampening : float (optional)
        dampening for momentum (default: 0)
    nesterov : bool (optional)
        enables Nesterov momentum (default: False)

    Other Parameters
    ----------------
    stabilize : int
        Stabilize parameters if they are off-manifold due to numerical
        reasons every ``stabilize`` steps (default: ``None`` -- no stabilize)

    Notes
    -----
    For every parameter the Euclidean gradient (plus ``weight_decay * point``)
    is converted with the manifold's ``egrad2rgrad`` before it is used.
    ``ManifoldParameter`` / ``ManifoldTensor`` use their own ``manifold``;
    other tensors use ``self._default_manifold``.

    With momentum, the buffer is seeded with that Riemannian gradient on the
    first step, mirroring how :class:`torch.optim.SGD` seeds it with the
    gradient. It is then moved together with the point via
    ``manifold.retr_transp``.

    This class does not write into ``point.grad`` itself.
    """

    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        stabilize=None,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        # ``stabilize`` is consumed by OptimMixin (not visible here).
        super().__init__(params, defaults, stabilize=stabilize)

    def step(self, closure=None):
        """
        Perform a single optimization step over all parameter groups.

        Parameters
        ----------
        closure : callable (optional)
            A closure that re-evaluates the model and returns the loss.

        Returns
        -------
        loss
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises
        ------
        RuntimeError
            If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # Like torch.optim.SGD: the closure must be able to build an autograd
            # graph even if step() itself is called inside torch.no_grad().
            with torch.enable_grad():
                loss = closure()
        with torch.no_grad():
            for group in self.param_groups:
                if "step" not in group:
                    group["step"] = 0
                weight_decay = group["weight_decay"]
                momentum = group["momentum"]
                dampening = group["dampening"]
                nesterov = group["nesterov"]
                learning_rate = group["lr"]
                group["step"] += 1
                for point in group["params"]:
                    grad = point.grad
                    if grad is None:
                        continue
                    if grad.is_sparse:
                        raise RuntimeError(
                            "RiemannianSGD does not support sparse gradients, use SparseRiemannianSGD instead"
                        )
                    state = self.state[point]
                    if isinstance(point, (ManifoldParameter, ManifoldTensor)):
                        manifold = point.manifold
                    else:
                        manifold = self._default_manifold

                    # Out-of-place so the caller's ``point.grad`` is not modified.
                    if weight_decay != 0:
                        grad = grad.add(point, alpha=weight_decay)
                    grad = manifold.egrad2rgrad(point, grad)

                    if momentum > 0:
                        momentum_buffer = state.get("momentum_buffer")
                        if momentum_buffer is None:
                            # First step: seed the buffer with the Riemannian
                            # gradient (torch.optim.SGD seeds with the gradient),
                            # not with the raw Euclidean one.
                            momentum_buffer = grad.clone()
                            state["momentum_buffer"] = momentum_buffer
                        else:
                            momentum_buffer.mul_(momentum).add_(
                                grad, alpha=1 - dampening
                            )
                        if nesterov:
                            # Out-of-place: ``grad`` may alias ``point.grad`` when
                            # egrad2rgrad returns its input unchanged.
                            grad = grad.add(momentum_buffer, alpha=momentum)
                        else:
                            grad = momentum_buffer
                        # Retract the point and transport the buffer along with it.
                        new_point, new_momentum_buffer = manifold.retr_transp(
                            point, -learning_rate * grad, momentum_buffer
                        )
                        momentum_buffer.copy_(new_momentum_buffer)
                        # Copy in place so the user-facing parameter object is kept.
                        point.copy_(new_point)
                    else:
                        new_point = manifold.retr(point, -learning_rate * grad)
                        point.copy_(new_point)

                if (
                    group["stabilize"] is not None
                    and group["step"] % group["stabilize"] == 0
                ):
                    self.stabilize_group(group)
        return loss

    @torch.no_grad()
    def stabilize_group(self, group):
        """
        Re-project a group's manifold parameters and their momentum buffers.

        Each ``ManifoldParameter`` / ``ManifoldTensor`` is replaced in place by
        ``manifold.projx(p)``. If momentum is enabled and a buffer exists, the
        buffer is replaced by ``manifold.proju(p, buf)``. Other tensors are
        skipped.
        """
        momentum = group["momentum"]
        for p in group["params"]:
            if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
                continue
            manifold = p.manifold
            p.copy_(manifold.projx(p))
            if momentum > 0:
                param_state = self.state[p]
                # No buffer yet if this parameter never received a gradient.
                if "momentum_buffer" in param_state:
                    buf = param_state["momentum_buffer"]
                    buf.copy_(manifold.proju(p, buf))