adadelta.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import torch
  2. from torch import Tensor
  3. from .optimizer import Optimizer
  4. from typing import List, Optional
  5. class Adadelta(Optimizer):
  6. r"""Implements Adadelta algorithm.
  7. .. math::
  8. \begin{aligned}
  9. &\rule{110mm}{0.4pt} \\
  10. &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)},
  11. \: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)},
  12. \: \lambda \text{ (weight decay)} \\
  13. &\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)},
  14. \: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex]
  15. &\rule{110mm}{0.4pt} \\
  16. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  17. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  18. &\hspace{5mm}if \: \lambda \neq 0 \\
  19. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  20. &\hspace{5mm} v_t \leftarrow v_{t-1} \rho + g^2_t (1 - \rho) \\
  21. &\hspace{5mm}\Delta x_t \leftarrow \frac{\sqrt{u_{t-1} +
  22. \epsilon }}{ \sqrt{v_t + \epsilon} }g_t \hspace{21mm} \\
  23. &\hspace{5mm} u_t \leftarrow u_{t-1} \rho +
  24. \Delta x^2_t (1 - \rho) \\
  25. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\
  26. &\rule{110mm}{0.4pt} \\[-1.ex]
  27. &\bf{return} \: \theta_t \\[-1.ex]
  28. &\rule{110mm}{0.4pt} \\[-1.ex]
  29. \end{aligned}
  30. For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_.
  31. Args:
  32. params (iterable): iterable of parameters to optimize or dicts defining
  33. parameter groups
  34. rho (float, optional): coefficient used for computing a running average
  35. of squared gradients (default: 0.9)
  36. eps (float, optional): term added to the denominator to improve
  37. numerical stability (default: 1e-6)
  38. lr (float, optional): coefficient that scale delta before it is applied
  39. to the parameters (default: 1.0)
  40. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  41. foreach (bool, optional): whether foreach implementation of optimizer is used (default: None)
  42. maximize (bool, optional): maximize the params based on the objective, instead of
  43. minimizing (default: False)
  44. .. _ADADELTA\: An Adaptive Learning Rate Method:
  45. https://arxiv.org/abs/1212.5701
  46. """
  47. def __init__(self, params, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0,
  48. foreach: Optional[bool] = None, *, maximize: bool = False):
  49. if not 0.0 <= lr:
  50. raise ValueError("Invalid learning rate: {}".format(lr))
  51. if not 0.0 <= rho <= 1.0:
  52. raise ValueError("Invalid rho value: {}".format(rho))
  53. if not 0.0 <= eps:
  54. raise ValueError("Invalid epsilon value: {}".format(eps))
  55. if not 0.0 <= weight_decay:
  56. raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
  57. defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay,
  58. maximize=maximize, foreach=foreach)
  59. super(Adadelta, self).__init__(params, defaults)
  60. def __setstate__(self, state):
  61. super().__setstate__(state)
  62. for group in self.param_groups:
  63. group.setdefault('foreach', None)
  64. group.setdefault('maximize', False)
  65. @torch.no_grad()
  66. def step(self, closure=None):
  67. """Performs a single optimization step.
  68. Args:
  69. closure (callable, optional): A closure that reevaluates the model
  70. and returns the loss.
  71. """
  72. loss = None
  73. if closure is not None:
  74. with torch.enable_grad():
  75. loss = closure()
  76. for group in self.param_groups:
  77. params_with_grad = []
  78. grads = []
  79. square_avgs = []
  80. acc_deltas = []
  81. lr, rho, eps, weight_decay, foreach, maximize = (group['lr'],
  82. group['rho'],
  83. group['eps'],
  84. group['weight_decay'],
  85. group['foreach'],
  86. group['maximize'])
  87. for p in group['params']:
  88. if p.grad is None:
  89. continue
  90. params_with_grad.append(p)
  91. if p.grad.is_sparse:
  92. raise RuntimeError('Adadelta does not support sparse gradients')
  93. grads.append(p.grad)
  94. state = self.state[p]
  95. # Lazy state initialization
  96. if len(state) == 0:
  97. state['step'] = 0
  98. state['square_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  99. state['acc_delta'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  100. square_avgs.append(state['square_avg'])
  101. acc_deltas.append(state['acc_delta'])
  102. state['step'] += 1
  103. adadelta(params_with_grad,
  104. grads,
  105. square_avgs,
  106. acc_deltas,
  107. lr=lr,
  108. rho=rho,
  109. eps=eps,
  110. weight_decay=weight_decay,
  111. foreach=foreach,
  112. maximize=maximize)
  113. return loss
  114. def adadelta(params: List[Tensor],
  115. grads: List[Tensor],
  116. square_avgs: List[Tensor],
  117. acc_deltas: List[Tensor],
  118. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  119. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  120. foreach: bool = None,
  121. *,
  122. lr: float,
  123. rho: float,
  124. eps: float,
  125. weight_decay: float,
  126. maximize: bool):
  127. r"""Functional API that performs Adadelta algorithm computation.
  128. See :class:`~torch.optim.Adadelta` for details.
  129. """
  130. if foreach is None:
  131. # Placeholder for more complex foreach logic to be added when value is not set
  132. foreach = False
  133. if foreach and torch.jit.is_scripting():
  134. raise RuntimeError('torch.jit.script not supported with foreach optimizers')
  135. if foreach and not torch.jit.is_scripting():
  136. func = _multi_tensor_adadelta
  137. else:
  138. func = _single_tensor_adadelta
  139. func(params,
  140. grads,
  141. square_avgs,
  142. acc_deltas,
  143. lr=lr,
  144. rho=rho,
  145. eps=eps,
  146. weight_decay=weight_decay,
  147. maximize=maximize)
  148. def _single_tensor_adadelta(params: List[Tensor],
  149. grads: List[Tensor],
  150. square_avgs: List[Tensor],
  151. acc_deltas: List[Tensor],
  152. *,
  153. lr: float,
  154. rho: float,
  155. eps: float,
  156. weight_decay: float,
  157. maximize: bool):
  158. for (param, grad, square_avg, acc_delta) in zip(params, grads, square_avgs, acc_deltas):
  159. grad = grad if not maximize else -grad
  160. if weight_decay != 0:
  161. grad = grad.add(param, alpha=weight_decay)
  162. if torch.is_complex(param):
  163. square_avg = torch.view_as_real(square_avg)
  164. acc_delta = torch.view_as_real(acc_delta)
  165. grad = torch.view_as_real(grad)
  166. square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
  167. std = square_avg.add(eps).sqrt_()
  168. delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
  169. acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
  170. if torch.is_complex(param):
  171. delta = torch.view_as_complex(delta)
  172. param.add_(delta, alpha=-lr)
  173. def _multi_tensor_adadelta(params: List[Tensor],
  174. grads: List[Tensor],
  175. square_avgs: List[Tensor],
  176. acc_deltas: List[Tensor],
  177. *,
  178. lr: float,
  179. weight_decay: float,
  180. rho: float,
  181. eps: float,
  182. maximize: bool):
  183. if len(params) == 0:
  184. return
  185. if maximize:
  186. grads = torch._foreach_neg(grads)
  187. if weight_decay != 0:
  188. torch._foreach_add_(grads, params, alpha=weight_decay)
  189. torch._foreach_mul_(square_avgs, rho)
  190. torch._foreach_addcmul_(square_avgs, grads, grads, value=1 - rho)
  191. std = torch._foreach_add(square_avgs, eps)
  192. torch._foreach_sqrt_(std)
  193. deltas = torch._foreach_add(acc_deltas, eps)
  194. torch._foreach_sqrt_(deltas)
  195. torch._foreach_div_(deltas, std)
  196. torch._foreach_mul_(deltas, grads)
  197. torch._foreach_add_(params, deltas, alpha=-lr)
  198. torch._foreach_mul_(acc_deltas, rho)
  199. torch._foreach_addcmul_(acc_deltas, deltas, deltas, value=1 - rho)