nadam.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. import math
  2. import torch
  3. from torch import Tensor
  4. from .optimizer import Optimizer
  5. from typing import List, Optional
  6. class NAdam(Optimizer):
  7. r"""Implements NAdam algorithm.
  8. .. math::
  9. \begin{aligned}
  10. &\rule{110mm}{0.4pt} \\
  11. &\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)},
  12. \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
  13. &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\
  14. &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
  15. v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex]
  16. &\rule{110mm}{0.4pt} \\
  17. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  18. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  19. &\hspace{5mm}if \: \lambda \neq 0 \\
  20. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  21. &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\
  22. &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\
  23. &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  24. &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
  25. &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex]
  26. & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\
  27. &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
  28. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
  29. \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
  30. &\rule{110mm}{0.4pt} \\[-1.ex]
  31. &\bf{return} \: \theta_t \\[-1.ex]
  32. &\rule{110mm}{0.4pt} \\[-1.ex]
  33. \end{aligned}
  34. For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_.
  35. Args:
  36. params (iterable): iterable of parameters to optimize or dicts defining
  37. parameter groups
  38. lr (float, optional): learning rate (default: 2e-3)
  39. betas (Tuple[float, float], optional): coefficients used for computing
  40. running averages of gradient and its square (default: (0.9, 0.999))
  41. eps (float, optional): term added to the denominator to improve
  42. numerical stability (default: 1e-8)
  43. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  44. momentum_decay (float, optional): momentum momentum_decay (default: 4e-3)
  45. foreach (bool, optional): whether foreach implementation of optimizer
  46. is used (default: None)
  47. .. _Incorporating Nesterov Momentum into Adam:
  48. https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
  49. """
  50. def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
  51. weight_decay=0, momentum_decay=4e-3, foreach: Optional[bool] = None):
  52. if not 0.0 <= lr:
  53. raise ValueError("Invalid learning rate: {}".format(lr))
  54. if not 0.0 <= eps:
  55. raise ValueError("Invalid epsilon value: {}".format(eps))
  56. if not 0.0 <= betas[0] < 1.0:
  57. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  58. if not 0.0 <= betas[1] < 1.0:
  59. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  60. if not 0.0 <= weight_decay:
  61. raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
  62. if not 0.0 <= momentum_decay:
  63. raise ValueError("Invalid momentum_decay value: {}".format(momentum_decay))
  64. defaults = dict(lr=lr, betas=betas, eps=eps,
  65. weight_decay=weight_decay, momentum_decay=momentum_decay,
  66. foreach=foreach)
  67. super(NAdam, self).__init__(params, defaults)
  68. def __setstate__(self, state):
  69. super().__setstate__(state)
  70. for group in self.param_groups:
  71. group.setdefault('foreach', None)
  72. state_values = list(self.state.values())
  73. step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
  74. if not step_is_tensor:
  75. for s in state_values:
  76. s['step'] = torch.tensor(float(s['step']))
  77. mu_product_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['mu_product'])
  78. if not mu_product_is_tensor:
  79. for s in state_values:
  80. s['mu_product'] = torch.tensor(s['mu_product'])
  81. @torch.no_grad()
  82. def step(self, closure=None):
  83. """Performs a single optimization step.
  84. Args:
  85. closure (callable, optional): A closure that reevaluates the model
  86. and returns the loss.
  87. """
  88. loss = None
  89. if closure is not None:
  90. with torch.enable_grad():
  91. loss = closure()
  92. for group in self.param_groups:
  93. params_with_grad = []
  94. grads = []
  95. exp_avgs = []
  96. exp_avg_sqs = []
  97. mu_products = []
  98. state_steps = []
  99. beta1, beta2 = group['betas']
  100. for p in group['params']:
  101. if p.grad is not None:
  102. params_with_grad.append(p)
  103. if p.grad.is_sparse:
  104. raise RuntimeError('NAdam does not support sparse gradients')
  105. grads.append(p.grad)
  106. state = self.state[p]
  107. # Lazy state initialization
  108. if len(state) == 0:
  109. state['step'] = torch.tensor(0.)
  110. state['mu_product'] = torch.tensor(1.)
  111. # Exponential moving average of gradient values
  112. state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  113. # Exponential moving average of squared gradient values
  114. state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  115. exp_avgs.append(state['exp_avg'])
  116. exp_avg_sqs.append(state['exp_avg_sq'])
  117. mu_products.append(state['mu_product'])
  118. state_steps.append(state['step'])
  119. nadam(params_with_grad,
  120. grads,
  121. exp_avgs,
  122. exp_avg_sqs,
  123. mu_products,
  124. state_steps,
  125. beta1=beta1,
  126. beta2=beta2,
  127. lr=group['lr'],
  128. weight_decay=group['weight_decay'],
  129. momentum_decay=group['momentum_decay'],
  130. eps=group['eps'],
  131. foreach=group['foreach'])
  132. return loss
  133. def nadam(params: List[Tensor],
  134. grads: List[Tensor],
  135. exp_avgs: List[Tensor],
  136. exp_avg_sqs: List[Tensor],
  137. mu_products: List[Tensor],
  138. state_steps: List[Tensor],
  139. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  140. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  141. foreach: bool = None,
  142. *,
  143. beta1: float,
  144. beta2: float,
  145. lr: float,
  146. weight_decay: float,
  147. momentum_decay: float,
  148. eps: float):
  149. r"""Functional API that performs NAdam algorithm computation.
  150. See :class:`~torch.optim.NAdam` for details.
  151. """
  152. if not all([isinstance(t, torch.Tensor) for t in state_steps]):
  153. raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
  154. if not all([isinstance(t, torch.Tensor) for t in mu_products]):
  155. raise RuntimeError("API has changed, `mu_products` argument must contain a list of singleton tensors")
  156. if foreach is None:
  157. # Placeholder for more complex foreach logic to be added when value is not set
  158. foreach = False
  159. if foreach and torch.jit.is_scripting():
  160. raise RuntimeError('torch.jit.script not supported with foreach optimizers')
  161. if foreach and not torch.jit.is_scripting():
  162. func = _multi_tensor_nadam
  163. else:
  164. func = _single_tensor_nadam
  165. func(params,
  166. grads,
  167. exp_avgs,
  168. exp_avg_sqs,
  169. mu_products,
  170. state_steps,
  171. beta1=beta1,
  172. beta2=beta2,
  173. lr=lr,
  174. weight_decay=weight_decay,
  175. momentum_decay=momentum_decay,
  176. eps=eps)
  177. def _single_tensor_nadam(params: List[Tensor],
  178. grads: List[Tensor],
  179. exp_avgs: List[Tensor],
  180. exp_avg_sqs: List[Tensor],
  181. mu_products: List[Tensor],
  182. state_steps: List[Tensor],
  183. *,
  184. beta1: float,
  185. beta2: float,
  186. lr: float,
  187. weight_decay: float,
  188. momentum_decay: float,
  189. eps: float):
  190. for i, param in enumerate(params):
  191. grad = grads[i]
  192. exp_avg = exp_avgs[i]
  193. exp_avg_sq = exp_avg_sqs[i]
  194. mu_product = mu_products[i]
  195. step_t = state_steps[i]
  196. # update step
  197. step_t += 1
  198. step = step_t.item()
  199. bias_correction2 = 1 - beta2 ** step
  200. if weight_decay != 0:
  201. grad = grad.add(param, alpha=weight_decay)
  202. # calculate the momentum cache \mu^{t} and \mu^{t+1}
  203. mu = beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay)))
  204. mu_next = beta1 * (1. - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))
  205. # update mu_product
  206. mu_product *= mu
  207. mu_product_next = mu_product * mu * mu_next
  208. # decay the first and second moment running average coefficient
  209. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  210. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
  211. denom = exp_avg_sq.div(bias_correction2).sqrt().add_(eps)
  212. param.addcdiv_(grad, denom, value=-lr * (1. - mu) / (1. - mu_product.item()))
  213. param.addcdiv_(exp_avg, denom, value=-lr * mu_next / (1. - mu_product_next.item()))
  214. def _multi_tensor_nadam(params: List[Tensor],
  215. grads: List[Tensor],
  216. exp_avgs: List[Tensor],
  217. exp_avg_sqs: List[Tensor],
  218. mu_products: List[Tensor],
  219. state_steps: List[Tensor],
  220. *,
  221. beta1: float,
  222. beta2: float,
  223. lr: float,
  224. weight_decay: float,
  225. momentum_decay: float,
  226. eps: float):
  227. if len(params) == 0:
  228. return
  229. # update steps
  230. torch._foreach_add_(state_steps, 1)
  231. bias_correction1 = [1 - beta1 ** step.item() for step in state_steps]
  232. bias_correction2 = [1 - beta2 ** step.item() for step in state_steps]
  233. mus = [beta1 * (1. - 0.5 * (0.96 ** (step.item() * momentum_decay))) for step in state_steps]
  234. mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((step.item() + 1) * momentum_decay)))
  235. for step in state_steps]
  236. # update mu_products
  237. torch._foreach_mul_(mu_products, mus)
  238. if weight_decay != 0:
  239. torch._foreach_add_(grads, params, alpha=weight_decay)
  240. # Decay the first and second moment running average coefficient
  241. torch._foreach_mul_(exp_avgs, beta1)
  242. torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
  243. torch._foreach_mul_(exp_avg_sqs, beta2)
  244. torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)
  245. exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
  246. bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
  247. torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
  248. denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
  249. step_size_grads = [(lr * (1. - mu) / (1. - mu_product.item())) * -1
  250. for mu_product, mu in zip(mu_products, mus)]
  251. step_size_expavg = [(lr * mu_next / (1. - mu_product.item() * mu_next)) * -1
  252. for mu_product, mu_next in zip(mu_products, mu_nexts)]
  253. torch._foreach_addcdiv_(params, grads, denom, step_size_grads)
  254. torch._foreach_addcdiv_(params, exp_avgs, denom, step_size_expavg)