# rprop.py
from typing import List, Optional

import torch
from torch import Tensor

from .optimizer import Optimizer
  5. class Rprop(Optimizer):
  6. r"""Implements the resilient backpropagation algorithm.
  7. .. math::
  8. \begin{aligned}
  9. &\rule{110mm}{0.4pt} \\
  10. &\textbf{input} : \theta_0 \in \mathbf{R}^d \text{ (params)},f(\theta)
  11. \text{ (objective)}, \\
  12. &\hspace{13mm} \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min}
  13. \text{ (step sizes)} \\
  14. &\textbf{initialize} : g^0_{prev} \leftarrow 0,
  15. \: \eta_0 \leftarrow \text{lr (learning rate)} \\
  16. &\rule{110mm}{0.4pt} \\
  17. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  18. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  19. &\hspace{5mm} \textbf{for} \text{ } i = 0, 1, \ldots, d-1 \: \mathbf{do} \\
  20. &\hspace{10mm} \textbf{if} \: g^i_{prev} g^i_t > 0 \\
  21. &\hspace{15mm} \eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+},
  22. \Gamma_{max}) \\
  23. &\hspace{10mm} \textbf{else if} \: g^i_{prev} g^i_t < 0 \\
  24. &\hspace{15mm} \eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-},
  25. \Gamma_{min}) \\
  26. &\hspace{15mm} g^i_t \leftarrow 0 \\
  27. &\hspace{10mm} \textbf{else} \: \\
  28. &\hspace{15mm} \eta^i_t \leftarrow \eta^i_{t-1} \\
  29. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1}- \eta_t \mathrm{sign}(g_t) \\
  30. &\hspace{5mm}g_{prev} \leftarrow g_t \\
  31. &\rule{110mm}{0.4pt} \\[-1.ex]
  32. &\bf{return} \: \theta_t \\[-1.ex]
  33. &\rule{110mm}{0.4pt} \\[-1.ex]
  34. \end{aligned}
  35. For further details regarding the algorithm we refer to the paper
  36. `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm
  37. <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.
  38. Args:
  39. params (iterable): iterable of parameters to optimize or dicts defining
  40. parameter groups
  41. lr (float, optional): learning rate (default: 1e-2)
  42. etas (Tuple[float, float], optional): pair of (etaminus, etaplis), that
  43. are multiplicative increase and decrease factors
  44. (default: (0.5, 1.2))
  45. step_sizes (Tuple[float, float], optional): a pair of minimal and
  46. maximal allowed step sizes (default: (1e-6, 50))
  47. foreach (bool, optional): whether foreach implementation of optimizer
  48. is used (default: None)
  49. """
  50. def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50),
  51. foreach: Optional[bool] = None):
  52. if not 0.0 <= lr:
  53. raise ValueError("Invalid learning rate: {}".format(lr))
  54. if not 0.0 < etas[0] < 1.0 < etas[1]:
  55. raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))
  56. defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes, foreach=foreach)
  57. super(Rprop, self).__init__(params, defaults)
  58. def __setstate__(self, state):
  59. super().__setstate__(state)
  60. for group in self.param_groups:
  61. group.setdefault('foreach', None)
  62. @torch.no_grad()
  63. def step(self, closure=None):
  64. """Performs a single optimization step.
  65. Args:
  66. closure (callable, optional): A closure that reevaluates the model
  67. and returns the loss.
  68. """
  69. loss = None
  70. if closure is not None:
  71. with torch.enable_grad():
  72. loss = closure()
  73. for group in self.param_groups:
  74. params = []
  75. grads = []
  76. prevs = []
  77. step_sizes = []
  78. etaminus, etaplus = group['etas']
  79. step_size_min, step_size_max = group['step_sizes']
  80. foreach = group['foreach']
  81. for p in group['params']:
  82. if p.grad is None:
  83. continue
  84. params.append(p)
  85. grad = p.grad
  86. if grad.is_sparse:
  87. raise RuntimeError('Rprop does not support sparse gradients')
  88. grads.append(grad)
  89. state = self.state[p]
  90. # State initialization
  91. if len(state) == 0:
  92. state['step'] = 0
  93. state['prev'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  94. state['step_size'] = grad.new().resize_as_(grad).fill_(group['lr'])
  95. prevs.append(state['prev'])
  96. step_sizes.append(state['step_size'])
  97. state['step'] += 1
  98. rprop(params,
  99. grads,
  100. prevs,
  101. step_sizes,
  102. step_size_min=step_size_min,
  103. step_size_max=step_size_max,
  104. etaminus=etaminus,
  105. etaplus=etaplus,
  106. foreach=foreach)
  107. return loss
  108. def rprop(params: List[Tensor],
  109. grads: List[Tensor],
  110. prevs: List[Tensor],
  111. step_sizes: List[Tensor],
  112. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  113. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  114. foreach: bool = None,
  115. *,
  116. step_size_min: float,
  117. step_size_max: float,
  118. etaminus: float,
  119. etaplus: float):
  120. r"""Functional API that performs rprop algorithm computation.
  121. See :class:`~torch.optim.Rprop` for details.
  122. """
  123. if foreach is None:
  124. # Placeholder for more complex foreach logic to be added when value is not set
  125. foreach = False
  126. if foreach and torch.jit.is_scripting():
  127. raise RuntimeError('torch.jit.script not supported with foreach optimizers')
  128. if foreach and not torch.jit.is_scripting():
  129. func = _multi_tensor_rprop
  130. else:
  131. func = _single_tensor_rprop
  132. func(params,
  133. grads,
  134. prevs,
  135. step_sizes,
  136. step_size_min=step_size_min,
  137. step_size_max=step_size_max,
  138. etaminus=etaminus,
  139. etaplus=etaplus)
  140. def _single_tensor_rprop(params: List[Tensor],
  141. grads: List[Tensor],
  142. prevs: List[Tensor],
  143. step_sizes: List[Tensor],
  144. *,
  145. step_size_min: float,
  146. step_size_max: float,
  147. etaminus: float,
  148. etaplus: float):
  149. for i, param in enumerate(params):
  150. grad = grads[i]
  151. prev = prevs[i]
  152. step_size = step_sizes[i]
  153. sign = grad.mul(prev).sign()
  154. sign[sign.gt(0)] = etaplus
  155. sign[sign.lt(0)] = etaminus
  156. sign[sign.eq(0)] = 1
  157. # update stepsizes with step size updates
  158. step_size.mul_(sign).clamp_(step_size_min, step_size_max)
  159. # for dir<0, dfdx=0
  160. # for dir>=0 dfdx=dfdx
  161. grad = grad.clone(memory_format=torch.preserve_format)
  162. grad[sign.eq(etaminus)] = 0
  163. # update parameters
  164. param.addcmul_(grad.sign(), step_size, value=-1)
  165. prev.copy_(grad)
  166. def _multi_tensor_rprop(params: List[Tensor],
  167. grads: List[Tensor],
  168. prevs: List[Tensor],
  169. step_sizes: List[Tensor],
  170. *,
  171. step_size_min: float,
  172. step_size_max: float,
  173. etaminus: float,
  174. etaplus: float):
  175. if len(params) == 0:
  176. return
  177. signs = torch._foreach_mul(grads, prevs)
  178. signs = [s.sign() for s in signs]
  179. for sign in signs:
  180. sign[sign.gt(0)] = etaplus
  181. sign[sign.lt(0)] = etaminus
  182. sign[sign.eq(0)] = 1
  183. # update stepsizes with step size updates
  184. torch._foreach_mul_(step_sizes, signs)
  185. for step_size in step_sizes:
  186. step_size.clamp_(step_size_min, step_size_max)
  187. # for dir<0, dfdx=0
  188. # for dir>=0 dfdx=dfdx
  189. for i in range(len(grads)):
  190. grads[i] = grads[i].clone(memory_format=torch.preserve_format)
  191. grads[i][signs[i].eq(etaminus)] = 0
  192. # update parameters
  193. grad_signs = [grad.sign() for grad in grads]
  194. torch._foreach_addcmul_(params, grad_signs, step_sizes, value=-1)
  195. for i in range(len(prevs)):
  196. prevs[i].copy_(grads[i])