asgd.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. import math
  2. import torch
  3. from torch import Tensor
  4. from .optimizer import Optimizer
  5. from typing import List, Optional
  6. class ASGD(Optimizer):
  7. """Implements Averaged Stochastic Gradient Descent.
  8. It has been proposed in `Acceleration of stochastic approximation by
  9. averaging`_.
  10. Args:
  11. params (iterable): iterable of parameters to optimize or dicts defining
  12. parameter groups
  13. lr (float, optional): learning rate (default: 1e-2)
  14. lambd (float, optional): decay term (default: 1e-4)
  15. alpha (float, optional): power for eta update (default: 0.75)
  16. t0 (float, optional): point at which to start averaging (default: 1e6)
  17. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  18. foreach (bool, optional): whether foreach implementation of optimizer
  19. is used (default: None)
  20. .. _Acceleration of stochastic approximation by averaging:
  21. https://dl.acm.org/citation.cfm?id=131098
  22. """
  23. def __init__(self, params, lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0,
  24. foreach: Optional[bool] = None):
  25. if not 0.0 <= lr:
  26. raise ValueError("Invalid learning rate: {}".format(lr))
  27. if not 0.0 <= weight_decay:
  28. raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
  29. defaults = dict(lr=lr, lambd=lambd, alpha=alpha, t0=t0,
  30. weight_decay=weight_decay, foreach=foreach)
  31. super(ASGD, self).__init__(params, defaults)
  32. def __setstate__(self, state):
  33. super().__setstate__(state)
  34. for group in self.param_groups:
  35. group.setdefault('foreach', None)
  36. state_values = list(self.state.values())
  37. step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
  38. if not step_is_tensor:
  39. for s in state_values:
  40. s['step'] = torch.tensor(float(s['step']))
  41. eta_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['eta'])
  42. if not eta_is_tensor:
  43. for s in state_values:
  44. s['eta'] = torch.tensor(s['eta'])
  45. mu_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['mu'])
  46. if not mu_is_tensor:
  47. for s in state_values:
  48. s['mu'] = torch.tensor(float(s['mu']))
  49. @torch.no_grad()
  50. def step(self, closure=None):
  51. """Performs a single optimization step.
  52. Args:
  53. closure (callable, optional): A closure that reevaluates the model
  54. and returns the loss.
  55. """
  56. loss = None
  57. if closure is not None:
  58. with torch.enable_grad():
  59. loss = closure()
  60. for group in self.param_groups:
  61. params_with_grad = []
  62. grads = []
  63. mus = []
  64. axs = []
  65. etas = []
  66. state_steps = []
  67. for p in group['params']:
  68. if p.grad is not None:
  69. params_with_grad.append(p)
  70. if p.grad.is_sparse:
  71. raise RuntimeError('ASGD does not support sparse gradients')
  72. grads.append(p.grad)
  73. state = self.state[p]
  74. # State initialization
  75. if len(state) == 0:
  76. state['step'] = torch.tensor(0.)
  77. state['eta'] = torch.tensor(group['lr'])
  78. state['mu'] = torch.tensor(1.)
  79. state['ax'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  80. mus.append(state['mu'])
  81. axs.append(state['ax'])
  82. etas.append(state['eta'])
  83. state_steps.append(state['step'])
  84. asgd(params_with_grad,
  85. grads,
  86. axs,
  87. mus,
  88. etas,
  89. state_steps,
  90. lambd=group['lambd'],
  91. lr=group['lr'],
  92. t0=group['t0'],
  93. alpha=group['alpha'],
  94. weight_decay=group['weight_decay'],
  95. foreach=group['foreach'])
  96. return loss
  97. def asgd(params: List[Tensor],
  98. grads: List[Tensor],
  99. axs: List[Tensor],
  100. mus: List[Tensor],
  101. etas: List[Tensor],
  102. state_steps: List[Tensor],
  103. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  104. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  105. foreach: bool = None,
  106. *,
  107. lambd: float,
  108. lr: float,
  109. t0: float,
  110. alpha: float,
  111. weight_decay: float):
  112. r"""Functional API that performs asgd algorithm computation.
  113. See :class:`~torch.optim.ASGD` for details.
  114. """
  115. if foreach is None:
  116. # Placeholder for more complex foreach logic to be added when value is not set
  117. foreach = False
  118. if foreach and torch.jit.is_scripting():
  119. raise RuntimeError('torch.jit.script not supported with foreach optimizers')
  120. if foreach and not torch.jit.is_scripting():
  121. func = _multi_tensor_asgd
  122. else:
  123. func = _single_tensor_asgd
  124. func(params,
  125. grads,
  126. axs,
  127. mus,
  128. etas,
  129. state_steps,
  130. lambd=lambd,
  131. lr=lr,
  132. t0=t0,
  133. alpha=alpha,
  134. weight_decay=weight_decay)
  135. def _single_tensor_asgd(params: List[Tensor],
  136. grads: List[Tensor],
  137. axs: List[Tensor],
  138. mus: List[Tensor],
  139. etas: List[Tensor],
  140. state_steps: List[Tensor],
  141. *,
  142. lambd: float,
  143. lr: float,
  144. t0: float,
  145. alpha: float,
  146. weight_decay: float):
  147. for i, param in enumerate(params):
  148. grad = grads[i]
  149. mu = mus[i]
  150. ax = axs[i]
  151. eta = etas[i]
  152. step_t = state_steps[i]
  153. # update step
  154. step_t += 1
  155. step = step_t.item()
  156. if weight_decay != 0:
  157. grad = grad.add(param, alpha=weight_decay)
  158. # decay term
  159. param.mul_(1 - lambd * eta.item())
  160. # update parameter
  161. param.add_(grad, alpha=-eta.item())
  162. # averaging
  163. if mu.item() != 1:
  164. ax.add_(param.sub(ax).mul(mu))
  165. else:
  166. ax.copy_(param)
  167. new_eta = torch.tensor(lr / math.pow((1 + lambd * lr * step), alpha))
  168. eta.copy_(new_eta)
  169. new_mu = torch.tensor(1 / max(1, step - t0))
  170. mu.copy_(new_mu)
  171. def _multi_tensor_asgd(params: List[Tensor],
  172. grads: List[Tensor],
  173. axs: List[Tensor],
  174. mus: List[Tensor],
  175. etas: List[Tensor],
  176. state_steps: List[Tensor],
  177. *,
  178. lambd: float,
  179. lr: float,
  180. t0: float,
  181. alpha: float,
  182. weight_decay: float):
  183. if len(params) == 0:
  184. return
  185. # update step
  186. torch._foreach_add_(state_steps, 1)
  187. if weight_decay != 0:
  188. torch._foreach_add_(grads, params, alpha=weight_decay)
  189. # decay term
  190. eta = etas[0].item()
  191. torch._foreach_mul_(params, 1 - lambd * eta)
  192. # update parameter
  193. torch._foreach_add_(params, grads, alpha=-eta)
  194. # averaging
  195. for i in range(len(axs)):
  196. if mus[i].item() != 1:
  197. axs[i].add_(params[i].sub(axs[i]).mul(mus[i]))
  198. else:
  199. axs[i].copy_(params[i])
  200. # update eta and mu
  201. for i in range(len(mus)):
  202. new_eta = torch.tensor(lr / math.pow((1 + lambd * lr * state_steps[i].item()), alpha))
  203. etas[i].copy_(new_eta)
  204. new_mu = torch.tensor(1 / max(1, state_steps[i].item() - t0))
  205. mus[i].copy_(new_mu)