adamax.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. import torch
  2. from torch import Tensor
  3. from .optimizer import Optimizer
  4. from typing import List, Optional
  5. class Adamax(Optimizer):
  6. r"""Implements Adamax algorithm (a variant of Adam based on infinity norm).
  7. .. math::
  8. \begin{aligned}
  9. &\rule{110mm}{0.4pt} \\
  10. &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
  11. \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
  12. \: \lambda \text{ (weight decay)}, \\
  13. &\hspace{13mm} \epsilon \text{ (epsilon)} \\
  14. &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
  15. u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex]
  16. &\rule{110mm}{0.4pt} \\
  17. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  18. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  19. &\hspace{5mm}if \: \lambda \neq 0 \\
  20. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  21. &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  22. &\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\
  23. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
  24. &\rule{110mm}{0.4pt} \\[-1.ex]
  25. &\bf{return} \: \theta_t \\[-1.ex]
  26. &\rule{110mm}{0.4pt} \\[-1.ex]
  27. \end{aligned}
  28. For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
  29. Args:
  30. params (iterable): iterable of parameters to optimize or dicts defining
  31. parameter groups
  32. lr (float, optional): learning rate (default: 2e-3)
  33. betas (Tuple[float, float], optional): coefficients used for computing
  34. running averages of gradient and its square
  35. eps (float, optional): term added to the denominator to improve
  36. numerical stability (default: 1e-8)
  37. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  38. foreach (bool, optional): whether foreach implementation of optimizer is used (default: None)
  39. maximize (bool, optional): maximize the params based on the objective, instead of
  40. minimizing (default: False)
  41. .. _Adam\: A Method for Stochastic Optimization:
  42. https://arxiv.org/abs/1412.6980
  43. """
  44. def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
  45. weight_decay=0, foreach: Optional[bool] = None, *, maximize: bool = False):
  46. if not 0.0 <= lr:
  47. raise ValueError("Invalid learning rate: {}".format(lr))
  48. if not 0.0 <= eps:
  49. raise ValueError("Invalid epsilon value: {}".format(eps))
  50. if not 0.0 <= betas[0] < 1.0:
  51. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  52. if not 0.0 <= betas[1] < 1.0:
  53. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  54. if not 0.0 <= weight_decay:
  55. raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
  56. defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
  57. foreach=foreach, maximize=maximize)
  58. super(Adamax, self).__init__(params, defaults)
  59. def __setstate__(self, state):
  60. super().__setstate__(state)
  61. for group in self.param_groups:
  62. group.setdefault('foreach', None)
  63. group.setdefault('maximize', False)
  64. state_values = list(self.state.values())
  65. step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
  66. if not step_is_tensor:
  67. for s in state_values:
  68. s['step'] = torch.tensor(float(s['step']))
  69. @torch.no_grad()
  70. def step(self, closure=None):
  71. """Performs a single optimization step.
  72. Args:
  73. closure (callable, optional): A closure that reevaluates the model
  74. and returns the loss.
  75. """
  76. loss = None
  77. if closure is not None:
  78. with torch.enable_grad():
  79. loss = closure()
  80. for group in self.param_groups:
  81. params_with_grad = []
  82. grads = []
  83. exp_avgs = []
  84. exp_infs = []
  85. state_steps = []
  86. beta1, beta2 = group['betas']
  87. eps = group['eps']
  88. lr = group['lr']
  89. weight_decay = group['weight_decay']
  90. foreach = group['foreach']
  91. maximize = group['maximize']
  92. for p in group['params']:
  93. if p.grad is None:
  94. continue
  95. params_with_grad.append(p)
  96. if p.grad.is_sparse:
  97. raise RuntimeError('Adamax does not support sparse gradients')
  98. grads.append(p.grad)
  99. state = self.state[p]
  100. # State initialization
  101. if len(state) == 0:
  102. state['step'] = torch.tensor(0.)
  103. state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  104. state['exp_inf'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  105. exp_avgs.append(state['exp_avg'])
  106. exp_infs.append(state['exp_inf'])
  107. state_steps.append(state['step'])
  108. adamax(params_with_grad,
  109. grads,
  110. exp_avgs,
  111. exp_infs,
  112. state_steps,
  113. eps=eps,
  114. beta1=beta1,
  115. beta2=beta2,
  116. lr=lr,
  117. weight_decay=weight_decay,
  118. foreach=foreach,
  119. maximize=maximize)
  120. return loss
  121. def adamax(params: List[Tensor],
  122. grads: List[Tensor],
  123. exp_avgs: List[Tensor],
  124. exp_infs: List[Tensor],
  125. state_steps: List[Tensor],
  126. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  127. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  128. foreach: bool = None,
  129. maximize: bool = False,
  130. *,
  131. eps: float,
  132. beta1: float,
  133. beta2: float,
  134. lr: float,
  135. weight_decay: float):
  136. r"""Functional API that performs adamax algorithm computation.
  137. See :class:`~torch.optim.Adamax` for details.
  138. """
  139. if not all([isinstance(t, torch.Tensor) for t in state_steps]):
  140. raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
  141. if foreach is None:
  142. # Placeholder for more complex foreach logic to be added when value is not set
  143. foreach = False
  144. if foreach and torch.jit.is_scripting():
  145. raise RuntimeError('torch.jit.script not supported with foreach optimizers')
  146. if foreach and not torch.jit.is_scripting():
  147. func = _multi_tensor_adamax
  148. else:
  149. func = _single_tensor_adamax
  150. func(params,
  151. grads,
  152. exp_avgs,
  153. exp_infs,
  154. state_steps,
  155. eps=eps,
  156. beta1=beta1,
  157. beta2=beta2,
  158. lr=lr,
  159. weight_decay=weight_decay,
  160. maximize=maximize)
  161. def _single_tensor_adamax(params: List[Tensor],
  162. grads: List[Tensor],
  163. exp_avgs: List[Tensor],
  164. exp_infs: List[Tensor],
  165. state_steps: List[Tensor],
  166. *,
  167. eps: float,
  168. beta1: float,
  169. beta2: float,
  170. lr: float,
  171. weight_decay: float,
  172. maximize: bool):
  173. for i, param in enumerate(params):
  174. grad = grads[i]
  175. grad = grad if not maximize else -grad
  176. exp_avg = exp_avgs[i]
  177. exp_inf = exp_infs[i]
  178. step_t = state_steps[i]
  179. # update step
  180. step_t += 1
  181. step = step_t.item()
  182. if weight_decay != 0:
  183. grad = grad.add(param, alpha=weight_decay)
  184. # Update biased first moment estimate.
  185. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  186. # Update the exponentially weighted infinity norm.
  187. norm_buf = torch.cat([
  188. exp_inf.mul_(beta2).unsqueeze(0),
  189. grad.abs().add_(eps).unsqueeze_(0)
  190. ], 0)
  191. torch.amax(norm_buf, 0, keepdim=False, out=exp_inf)
  192. bias_correction = 1 - beta1 ** step
  193. clr = lr / bias_correction
  194. param.addcdiv_(exp_avg, exp_inf, value=-clr)
  195. def _multi_tensor_adamax(params: List[Tensor],
  196. grads: List[Tensor],
  197. exp_avgs: List[Tensor],
  198. exp_infs: List[Tensor],
  199. state_steps: List[Tensor],
  200. *,
  201. beta1: float,
  202. beta2: float,
  203. lr: float,
  204. weight_decay: float,
  205. eps: float,
  206. maximize: bool):
  207. if len(params) == 0:
  208. return
  209. if maximize:
  210. grads = torch._foreach_neg(grads)
  211. # Update steps
  212. torch._foreach_add_(state_steps, 1)
  213. if weight_decay != 0:
  214. torch._foreach_add_(grads, params, alpha=weight_decay)
  215. # Update biased first moment estimate.
  216. torch._foreach_mul_(exp_avgs, beta1)
  217. torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
  218. # Update the exponentially weighted infinity norm.
  219. torch._foreach_mul_(exp_infs, beta2)
  220. for exp_inf, grad in zip(exp_infs, grads):
  221. norm_buf = torch.cat([
  222. exp_inf.unsqueeze(0),
  223. grad.abs().add_(eps).unsqueeze_(0)
  224. ], 0)
  225. torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long()))
  226. bias_corrections = [1 - beta1 ** step.item() for step in state_steps]
  227. clr = [-1 * (lr / bias_correction) for bias_correction in bias_corrections]
  228. torch._foreach_addcdiv_(params, exp_avgs, exp_infs, clr)