sgd.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. import torch
  2. from torch import Tensor
  3. from .optimizer import Optimizer, required
  4. from typing import List, Optional
  5. class SGD(Optimizer):
  6. r"""Implements stochastic gradient descent (optionally with momentum).
  7. .. math::
  8. \begin{aligned}
  9. &\rule{110mm}{0.4pt} \\
  10. &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
  11. \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
  12. &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
  13. \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
  14. &\rule{110mm}{0.4pt} \\
  15. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  16. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  17. &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
  18. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  19. &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
  20. &\hspace{10mm}\textbf{if} \: t > 1 \\
  21. &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
  22. &\hspace{10mm}\textbf{else} \\
  23. &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
  24. &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
  25. &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
  26. &\hspace{10mm}\textbf{else} \\[-1.ex]
  27. &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
  28. &\hspace{5mm}\textbf{if} \: \textit{maximize} \\
  29. &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex]
  30. &\hspace{5mm}\textbf{else} \\[-1.ex]
  31. &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
  32. &\rule{110mm}{0.4pt} \\[-1.ex]
  33. &\bf{return} \: \theta_t \\[-1.ex]
  34. &\rule{110mm}{0.4pt} \\[-1.ex]
  35. \end{aligned}
  36. Nesterov momentum is based on the formula from
  37. `On the importance of initialization and momentum in deep learning`__.
  38. Args:
  39. params (iterable): iterable of parameters to optimize or dicts defining
  40. parameter groups
  41. lr (float): learning rate
  42. momentum (float, optional): momentum factor (default: 0)
  43. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  44. dampening (float, optional): dampening for momentum (default: 0)
  45. nesterov (bool, optional): enables Nesterov momentum (default: False)
  46. maximize (bool, optional): maximize the params based on the objective, instead of
  47. minimizing (default: False)
  48. foreach (bool, optional): whether foreach implementation of optimizer
  49. is used (default: None)
  50. Example:
  51. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  52. >>> optimizer.zero_grad()
  53. >>> loss_fn(model(input), target).backward()
  54. >>> optimizer.step()
  55. __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
  56. .. note::
  57. The implementation of SGD with Momentum/Nesterov subtly differs from
  58. Sutskever et. al. and implementations in some other frameworks.
  59. Considering the specific case of Momentum, the update can be written as
  60. .. math::
  61. \begin{aligned}
  62. v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
  63. p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
  64. \end{aligned}
  65. where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
  66. parameters, gradient, velocity, and momentum respectively.
  67. This is in contrast to Sutskever et. al. and
  68. other frameworks which employ an update of the form
  69. .. math::
  70. \begin{aligned}
  71. v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
  72. p_{t+1} & = p_{t} - v_{t+1}.
  73. \end{aligned}
  74. The Nesterov version is analogously modified.
  75. """
  76. def __init__(self, params, lr=required, momentum=0, dampening=0,
  77. weight_decay=0, nesterov=False, *, maximize=False, foreach: Optional[bool] = None):
  78. if lr is not required and lr < 0.0:
  79. raise ValueError("Invalid learning rate: {}".format(lr))
  80. if momentum < 0.0:
  81. raise ValueError("Invalid momentum value: {}".format(momentum))
  82. if weight_decay < 0.0:
  83. raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
  84. defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
  85. weight_decay=weight_decay, nesterov=nesterov,
  86. maximize=maximize, foreach=foreach)
  87. if nesterov and (momentum <= 0 or dampening != 0):
  88. raise ValueError("Nesterov momentum requires a momentum and zero dampening")
  89. super(SGD, self).__init__(params, defaults)
  90. def __setstate__(self, state):
  91. super().__setstate__(state)
  92. for group in self.param_groups:
  93. group.setdefault('nesterov', False)
  94. group.setdefault('maximize', False)
  95. group.setdefault('foreach', None)
  96. @torch.no_grad()
  97. def step(self, closure=None):
  98. """Performs a single optimization step.
  99. Args:
  100. closure (callable, optional): A closure that reevaluates the model
  101. and returns the loss.
  102. """
  103. loss = None
  104. if closure is not None:
  105. with torch.enable_grad():
  106. loss = closure()
  107. for group in self.param_groups:
  108. params_with_grad = []
  109. d_p_list = []
  110. momentum_buffer_list = []
  111. has_sparse_grad = False
  112. for p in group['params']:
  113. if p.grad is not None:
  114. params_with_grad.append(p)
  115. d_p_list.append(p.grad)
  116. if p.grad.is_sparse:
  117. has_sparse_grad = True
  118. state = self.state[p]
  119. if 'momentum_buffer' not in state:
  120. momentum_buffer_list.append(None)
  121. else:
  122. momentum_buffer_list.append(state['momentum_buffer'])
  123. sgd(params_with_grad,
  124. d_p_list,
  125. momentum_buffer_list,
  126. weight_decay=group['weight_decay'],
  127. momentum=group['momentum'],
  128. lr=group['lr'],
  129. dampening=group['dampening'],
  130. nesterov=group['nesterov'],
  131. maximize=group['maximize'],
  132. has_sparse_grad=has_sparse_grad,
  133. foreach=group['foreach'])
  134. # update momentum_buffers in state
  135. for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
  136. state = self.state[p]
  137. state['momentum_buffer'] = momentum_buffer
  138. return loss
def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
        # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
        # NOTE(review): both defaults are None although annotated ``bool``; left
        # as-is because changing the annotation affects TorchScript compilation.
        has_sparse_grad: bool = None,
        foreach: bool = None,
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.

    Dispatches to the single-tensor or the multi-tensor (``torch._foreach_*``)
    implementation and forwards every argument unchanged.

    Args:
        params: parameters to update in place, index-aligned with ``d_p_list``.
        d_p_list: gradients, one per entry of ``params``.
        momentum_buffer_list: per-parameter momentum buffers, index-aligned with
            ``params``. When ``momentum != 0`` the implementations replace
            ``None`` entries in place with new buffers, so the caller can read
            them back after the call.
        has_sparse_grad: whether any gradient is sparse. The multi-tensor path
            recomputes it when ``None``; the single-tensor path ignores it.
        foreach: select the multi-tensor implementation. ``None`` is currently
            treated as ``False``.
        weight_decay, momentum, lr, dampening, nesterov, maximize: optimizer
            hyperparameters, see :class:`~torch.optim.SGD`.

    Raises:
        RuntimeError: if ``foreach`` is requested while running under
            TorchScript (``torch.jit.is_scripting()`` is true).
    """

    if foreach is None:
        # Placeholder for more complex foreach logic to be added when value is not set
        foreach = False

    # The foreach path is not available under TorchScript; fail loudly rather
    # than silently falling back to the single-tensor path.
    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    # NOTE(review): the redundant ``not torch.jit.is_scripting()`` is presumably
    # kept so TorchScript can treat the multi-tensor branch as dead code when
    # scripting — confirm before simplifying this condition.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    else:
        func = _single_tensor_sgd

    func(params,
         d_p_list,
         momentum_buffer_list,
         weight_decay=weight_decay,
         momentum=momentum,
         lr=lr,
         dampening=dampening,
         nesterov=nesterov,
         has_sparse_grad=has_sparse_grad,
         maximize=maximize)
  175. def _single_tensor_sgd(params: List[Tensor],
  176. d_p_list: List[Tensor],
  177. momentum_buffer_list: List[Optional[Tensor]],
  178. *,
  179. weight_decay: float,
  180. momentum: float,
  181. lr: float,
  182. dampening: float,
  183. nesterov: bool,
  184. maximize: bool,
  185. has_sparse_grad: bool):
  186. for i, param in enumerate(params):
  187. d_p = d_p_list[i]
  188. if weight_decay != 0:
  189. d_p = d_p.add(param, alpha=weight_decay)
  190. if momentum != 0:
  191. buf = momentum_buffer_list[i]
  192. if buf is None:
  193. buf = torch.clone(d_p).detach()
  194. momentum_buffer_list[i] = buf
  195. else:
  196. buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
  197. if nesterov:
  198. d_p = d_p.add(buf, alpha=momentum)
  199. else:
  200. d_p = buf
  201. alpha = lr if maximize else -lr
  202. param.add_(d_p, alpha=alpha)
  203. def _multi_tensor_sgd(params: List[Tensor],
  204. grads: List[Tensor],
  205. momentum_buffer_list: List[Optional[Tensor]],
  206. *,
  207. weight_decay: float,
  208. momentum: float,
  209. lr: float,
  210. dampening: float,
  211. nesterov: bool,
  212. maximize: bool,
  213. has_sparse_grad: bool):
  214. if len(params) == 0:
  215. return
  216. if has_sparse_grad is None:
  217. has_sparse_grad = any([grad.is_sparse for grad in grads])
  218. if weight_decay != 0:
  219. grads = torch._foreach_add(grads, params, alpha=weight_decay)
  220. if momentum != 0:
  221. bufs = []
  222. all_states_with_momentum_buffer = True
  223. for i in range(len(momentum_buffer_list)):
  224. if momentum_buffer_list[i] is None:
  225. all_states_with_momentum_buffer = False
  226. break
  227. else:
  228. bufs.append(momentum_buffer_list[i])
  229. if all_states_with_momentum_buffer:
  230. torch._foreach_mul_(bufs, momentum)
  231. torch._foreach_add_(bufs, grads, alpha=1 - dampening)
  232. else:
  233. bufs = []
  234. for i in range(len(momentum_buffer_list)):
  235. if momentum_buffer_list[i] is None:
  236. buf = momentum_buffer_list[i] = torch.clone(grads[i]).detach()
  237. else:
  238. buf = momentum_buffer_list[i]
  239. buf.mul_(momentum).add_(grads[i], alpha=1 - dampening)
  240. bufs.append(buf)
  241. if nesterov:
  242. torch._foreach_add_(grads, bufs, alpha=momentum)
  243. else:
  244. grads = bufs
  245. alpha = lr if maximize else -lr
  246. if not has_sparse_grad:
  247. torch._foreach_add_(params, grads, alpha=alpha)
  248. else:
  249. # foreach APIs dont support sparse
  250. for i in range(len(params)):
  251. params[i].add_(grads[i], alpha=alpha)