radam.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. import math
  2. import torch
  3. from torch import Tensor
  4. from .optimizer import Optimizer
  5. from typing import List, Optional
  6. class RAdam(Optimizer):
  7. r"""Implements RAdam algorithm.
  8. .. math::
  9. \begin{aligned}
  10. &\rule{110mm}{0.4pt} \\
  11. &\textbf{input} : \gamma \text{ (lr)}, \: \beta_1, \beta_2
  12. \text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \:
  13. \lambda \text{ (weightdecay)}, \\
  14. &\hspace{13mm} \epsilon \text{ (epsilon)} \\
  15. &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
  16. v_0 \leftarrow 0 \text{ ( second moment)}, \\
  17. &\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1 \\[-1.ex]
  18. &\rule{110mm}{0.4pt} \\
  19. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  20. &\hspace{6mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  21. &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
  22. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  23. &\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  24. &\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
  25. &\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
  26. &\hspace{6mm}\rho_t \leftarrow \rho_{\infty} -
  27. 2 t \beta^t_2 /\big(1-\beta_2^t \big) \\[0.1.ex]
  28. &\hspace{6mm}\textbf{if} \: \rho_t > 5 \\
  29. &\hspace{12mm} l_t \leftarrow \sqrt{ (1-\beta^t_2) / \big( v_t +\epsilon \big) } \\
  30. &\hspace{12mm} r_t \leftarrow
  31. \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
  32. &\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} r_t l_t \\
  33. &\hspace{6mm}\textbf{else} \\
  34. &\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} \\
  35. &\rule{110mm}{0.4pt} \\[-1.ex]
  36. &\bf{return} \: \theta_t \\[-1.ex]
  37. &\rule{110mm}{0.4pt} \\[-1.ex]
  38. \end{aligned}
  39. For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_.
  40. Args:
  41. params (iterable): iterable of parameters to optimize or dicts defining
  42. parameter groups
  43. lr (float, optional): learning rate (default: 1e-3)
  44. betas (Tuple[float, float], optional): coefficients used for computing
  45. running averages of gradient and its square (default: (0.9, 0.999))
  46. eps (float, optional): term added to the denominator to improve
  47. numerical stability (default: 1e-8)
  48. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  49. foreach (bool, optional): whether foreach implementation of optimizer
  50. is used (default: None)
  51. .. _On the variance of the adaptive learning rate and beyond:
  52. https://arxiv.org/abs/1908.03265
  53. """
  54. def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
  55. weight_decay=0, foreach: Optional[bool] = None):
  56. if not 0.0 <= lr:
  57. raise ValueError("Invalid learning rate: {}".format(lr))
  58. if not 0.0 <= eps:
  59. raise ValueError("Invalid epsilon value: {}".format(eps))
  60. if not 0.0 <= betas[0] < 1.0:
  61. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  62. if not 0.0 <= betas[1] < 1.0:
  63. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  64. if not 0.0 <= weight_decay:
  65. raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
  66. defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
  67. foreach=foreach)
  68. super(RAdam, self).__init__(params, defaults)
  69. def __setstate__(self, state):
  70. super().__setstate__(state)
  71. for group in self.param_groups:
  72. group.setdefault('foreach', None)
  73. state_values = list(self.state.values())
  74. step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
  75. if not step_is_tensor:
  76. for s in state_values:
  77. s['step'] = torch.tensor(float(s['step']))
  78. @torch.no_grad()
  79. def step(self, closure=None):
  80. """Performs a single optimization step.
  81. Args:
  82. closure (callable, optional): A closure that reevaluates the model
  83. and returns the loss.
  84. """
  85. loss = None
  86. if closure is not None:
  87. with torch.enable_grad():
  88. loss = closure()
  89. for group in self.param_groups:
  90. params_with_grad = []
  91. grads = []
  92. exp_avgs = []
  93. exp_avg_sqs = []
  94. state_steps = []
  95. beta1, beta2 = group['betas']
  96. for p in group['params']:
  97. if p.grad is not None:
  98. params_with_grad.append(p)
  99. if p.grad.is_sparse:
  100. raise RuntimeError('RAdam does not support sparse gradients')
  101. grads.append(p.grad)
  102. state = self.state[p]
  103. # Lazy state initialization
  104. if len(state) == 0:
  105. state['step'] = torch.tensor(0.)
  106. # Exponential moving average of gradient values
  107. state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  108. # Exponential moving average of squared gradient values
  109. state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  110. exp_avgs.append(state['exp_avg'])
  111. exp_avg_sqs.append(state['exp_avg_sq'])
  112. state_steps.append(state['step'])
  113. radam(params_with_grad,
  114. grads,
  115. exp_avgs,
  116. exp_avg_sqs,
  117. state_steps,
  118. beta1=beta1,
  119. beta2=beta2,
  120. lr=group['lr'],
  121. weight_decay=group['weight_decay'],
  122. eps=group['eps'],
  123. foreach=group['foreach'])
  124. return loss
  125. def radam(params: List[Tensor],
  126. grads: List[Tensor],
  127. exp_avgs: List[Tensor],
  128. exp_avg_sqs: List[Tensor],
  129. state_steps: List[Tensor],
  130. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  131. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  132. foreach: bool = None,
  133. *,
  134. beta1: float,
  135. beta2: float,
  136. lr: float,
  137. weight_decay: float,
  138. eps: float):
  139. r"""Functional API that performs RAdam algorithm computation.
  140. See :class:`~torch.optim.RAdam` for details.
  141. """
  142. if not all([isinstance(t, torch.Tensor) for t in state_steps]):
  143. raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
  144. if foreach is None:
  145. # Placeholder for more complex foreach logic to be added when value is not set
  146. foreach = False
  147. if foreach and torch.jit.is_scripting():
  148. raise RuntimeError('torch.jit.script not supported with foreach optimizers')
  149. if foreach and not torch.jit.is_scripting():
  150. func = _multi_tensor_radam
  151. else:
  152. func = _single_tensor_radam
  153. func(params,
  154. grads,
  155. exp_avgs,
  156. exp_avg_sqs,
  157. state_steps,
  158. beta1=beta1,
  159. beta2=beta2,
  160. lr=lr,
  161. weight_decay=weight_decay,
  162. eps=eps)
  163. def _single_tensor_radam(params: List[Tensor],
  164. grads: List[Tensor],
  165. exp_avgs: List[Tensor],
  166. exp_avg_sqs: List[Tensor],
  167. state_steps: List[Tensor],
  168. *,
  169. beta1: float,
  170. beta2: float,
  171. lr: float,
  172. weight_decay: float,
  173. eps: float):
  174. for i, param in enumerate(params):
  175. grad = grads[i]
  176. exp_avg = exp_avgs[i]
  177. exp_avg_sq = exp_avg_sqs[i]
  178. step_t = state_steps[i]
  179. # update step
  180. step_t += 1
  181. step = step_t.item()
  182. bias_correction1 = 1 - beta1 ** step
  183. bias_correction2 = 1 - beta2 ** step
  184. if weight_decay != 0:
  185. grad = grad.add(param, alpha=weight_decay)
  186. # Decay the first and second moment running average coefficient
  187. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  188. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
  189. # correcting bias for the first moving moment
  190. bias_corrected_exp_avg = exp_avg / bias_correction1
  191. # maximum length of the approximated SMA
  192. rho_inf = 2 / (1 - beta2) - 1
  193. # compute the length of the approximated SMA
  194. rho_t = rho_inf - 2 * step * (beta2 ** step) / bias_correction2
  195. if rho_t > 5.:
  196. # Compute the variance rectification term and update parameters accordingly
  197. rect = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
  198. adaptive_lr = math.sqrt(bias_correction2) / exp_avg_sq.sqrt().add_(eps)
  199. param.add_(bias_corrected_exp_avg * lr * adaptive_lr * rect, alpha=-1.0)
  200. else:
  201. param.add_(bias_corrected_exp_avg * lr, alpha=-1.0)
  202. def _multi_tensor_radam(params: List[Tensor],
  203. grads: List[Tensor],
  204. exp_avgs: List[Tensor],
  205. exp_avg_sqs: List[Tensor],
  206. state_steps: List[Tensor],
  207. *,
  208. beta1: float,
  209. beta2: float,
  210. lr: float,
  211. weight_decay: float,
  212. eps: float):
  213. if len(params) == 0:
  214. return
  215. # Update steps
  216. torch._foreach_add_(state_steps, 1)
  217. # maximum length of the approximated SMA
  218. rho_inf = 2 / (1 - beta2) - 1
  219. # compute the length of the approximated SMA
  220. rho_t_list = [rho_inf - 2 * step.item() * (beta2 ** step.item()) / (1 - beta2 ** step.item()) for step in state_steps]
  221. bias_correction1 = [1 - beta1 ** step.item() for step in state_steps]
  222. bias_correction2 = [1 - beta2 ** step.item() for step in state_steps]
  223. if weight_decay != 0:
  224. torch._foreach_add_(grads, params, alpha=weight_decay)
  225. # Decay the first and second moment running average coefficient
  226. torch._foreach_mul_(exp_avgs, beta1)
  227. torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
  228. torch._foreach_mul_(exp_avg_sqs, beta2)
  229. torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)
  230. rect = [math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
  231. if rho_t > 5 else 0 for rho_t in rho_t_list]
  232. unrectified = [0 if rect > 0 else 1. for rect in rect]
  233. exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
  234. bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
  235. denom = torch._foreach_div(exp_avg_sq_sqrt, bias_correction_sqrt)
  236. step_size = [(lr * rect / bc) * -1 for rect, bc in zip(rect, bias_correction1)]
  237. torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)
  238. denom = [torch.ones_like(exp_av, memory_format=torch.preserve_format) for exp_av in exp_avgs]
  239. step_size = [(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)]
  240. torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)